dax: add support for fsync/sync
[cascardo/linux.git] / fs / dax.c
index 43671b6..d5f6aca 100644 (file)
--- a/fs/dax.c
+++ b/fs/dax.c
 #include <linux/memcontrol.h>
 #include <linux/mm.h>
 #include <linux/mutex.h>
+#include <linux/pagevec.h>
 #include <linux/pmem.h>
 #include <linux/sched.h>
 #include <linux/uio.h>
 #include <linux/vmstat.h>
+#include <linux/pfn_t.h>
+#include <linux/sizes.h>
+
+static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
+{
+       struct request_queue *q = bdev->bd_queue;
+       long rc = -EIO;
+
+       dax->addr = (void __pmem *) ERR_PTR(-EIO);
+       if (blk_queue_enter(q, true) != 0)
+               return rc;
+
+       rc = bdev_direct_access(bdev, dax);
+       if (rc < 0) {
+               dax->addr = (void __pmem *) ERR_PTR(rc);
+               blk_queue_exit(q);
+               return rc;
+       }
+       return rc;
+}
+
+static void dax_unmap_atomic(struct block_device *bdev,
+               const struct blk_dax_ctl *dax)
+{
+       if (IS_ERR(dax->addr))
+               return;
+       blk_queue_exit(bdev->bd_queue);
+}
 
 /*
  * dax_clear_blocks() is called from within transaction context from XFS,
  * and hence this means the stack from this point must follow GFP_NOFS
  * semantics for all operations.
  */
-int dax_clear_blocks(struct inode *inode, sector_t block, long size)
+int dax_clear_blocks(struct inode *inode, sector_t block, long _size)
 {
        struct block_device *bdev = inode->i_sb->s_bdev;
-       sector_t sector = block << (inode->i_blkbits - 9);
+       struct blk_dax_ctl dax = {
+               .sector = block << (inode->i_blkbits - 9),
+               .size = _size,
+       };
 
        might_sleep();
        do {
-               void __pmem *addr;
-               unsigned long pfn;
-               long count;
+               long count, sz;
 
-               count = bdev_direct_access(bdev, sector, &addr, &pfn, size);
+               count = dax_map_atomic(bdev, &dax);
                if (count < 0)
                        return count;
-               BUG_ON(size < count);
-               while (count > 0) {
-                       unsigned pgsz = PAGE_SIZE - offset_in_page(addr);
-                       if (pgsz > count)
-                               pgsz = count;
-                       clear_pmem(addr, pgsz);
-                       addr += pgsz;
-                       size -= pgsz;
-                       count -= pgsz;
-                       BUG_ON(pgsz & 511);
-                       sector += pgsz / 512;
-                       cond_resched();
-               }
-       } while (size);
+               sz = min_t(long, count, SZ_128K);
+               clear_pmem(dax.addr, sz);
+               dax.size -= sz;
+               dax.sector += sz / 512;
+               dax_unmap_atomic(bdev, &dax);
+               cond_resched();
+       } while (dax.size);
 
        wmb_pmem();
        return 0;
 }
 EXPORT_SYMBOL_GPL(dax_clear_blocks);
 
-static long dax_get_addr(struct buffer_head *bh, void __pmem **addr,
-               unsigned blkbits)
-{
-       unsigned long pfn;
-       sector_t sector = bh->b_blocknr << (blkbits - 9);
-       return bdev_direct_access(bh->b_bdev, sector, addr, &pfn, bh->b_size);
-}
-
 /* the clear_pmem() calls are ordered by a wmb_pmem() in the caller */
 static void dax_new_buf(void __pmem *addr, unsigned size, unsigned first,
                loff_t pos, loff_t end)
@@ -105,19 +120,29 @@ static bool buffer_size_valid(struct buffer_head *bh)
        return bh->b_state != 0;
 }
 
+
+static sector_t to_sector(const struct buffer_head *bh,
+               const struct inode *inode)
+{
+       sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
+
+       return sector;
+}
+
 static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
                      loff_t start, loff_t end, get_block_t get_block,
                      struct buffer_head *bh)
 {
-       ssize_t retval = 0;
-       loff_t pos = start;
-       loff_t max = start;
-       loff_t bh_max = start;
-       void __pmem *addr;
-       bool hole = false;
-       bool need_wmb = false;
-
-       if (iov_iter_rw(iter) != WRITE)
+       loff_t pos = start, max = start, bh_max = start;
+       bool hole = false, need_wmb = false;
+       struct block_device *bdev = NULL;
+       int rw = iov_iter_rw(iter), rc;
+       long map_len = 0;
+       struct blk_dax_ctl dax = {
+               .addr = (void __pmem *) ERR_PTR(-EIO),
+       };
+
+       if (rw == READ)
                end = min(end, i_size_read(inode));
 
        while (pos < end) {
@@ -132,13 +157,13 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
                        if (pos == bh_max) {
                                bh->b_size = PAGE_ALIGN(end - pos);
                                bh->b_state = 0;
-                               retval = get_block(inode, block, bh,
-                                                  iov_iter_rw(iter) == WRITE);
-                               if (retval)
+                               rc = get_block(inode, block, bh, rw == WRITE);
+                               if (rc)
                                        break;
                                if (!buffer_size_valid(bh))
                                        bh->b_size = 1 << blkbits;
                                bh_max = pos - first + bh->b_size;
+                               bdev = bh->b_bdev;
                        } else {
                                unsigned done = bh->b_size -
                                                (bh_max - (pos - first));
@@ -146,47 +171,53 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
                                bh->b_size -= done;
                        }
 
-                       hole = iov_iter_rw(iter) != WRITE && !buffer_written(bh);
+                       hole = rw == READ && !buffer_written(bh);
                        if (hole) {
-                               addr = NULL;
                                size = bh->b_size - first;
                        } else {
-                               retval = dax_get_addr(bh, &addr, blkbits);
-                               if (retval < 0)
+                               dax_unmap_atomic(bdev, &dax);
+                               dax.sector = to_sector(bh, inode);
+                               dax.size = bh->b_size;
+                               map_len = dax_map_atomic(bdev, &dax);
+                               if (map_len < 0) {
+                                       rc = map_len;
                                        break;
+                               }
                                if (buffer_unwritten(bh) || buffer_new(bh)) {
-                                       dax_new_buf(addr, retval, first, pos,
-                                                                       end);
+                                       dax_new_buf(dax.addr, map_len, first,
+                                                       pos, end);
                                        need_wmb = true;
                                }
-                               addr += first;
-                               size = retval - first;
+                               dax.addr += first;
+                               size = map_len - first;
                        }
                        max = min(pos + size, end);
                }
 
                if (iov_iter_rw(iter) == WRITE) {
-                       len = copy_from_iter_pmem(addr, max - pos, iter);
+                       len = copy_from_iter_pmem(dax.addr, max - pos, iter);
                        need_wmb = true;
                } else if (!hole)
-                       len = copy_to_iter((void __force *)addr, max - pos,
+                       len = copy_to_iter((void __force *) dax.addr, max - pos,
                                        iter);
                else
                        len = iov_iter_zero(max - pos, iter);
 
                if (!len) {
-                       retval = -EFAULT;
+                       rc = -EFAULT;
                        break;
                }
 
                pos += len;
-               addr += len;
+               if (!IS_ERR(dax.addr))
+                       dax.addr += len;
        }
 
        if (need_wmb)
                wmb_pmem();
+       dax_unmap_atomic(bdev, &dax);
 
-       return (pos == start) ? retval : pos - start;
+       return (pos == start) ? rc : pos - start;
 }
 
 /**
@@ -275,28 +306,228 @@ static int dax_load_hole(struct address_space *mapping, struct page *page,
        return VM_FAULT_LOCKED;
 }
 
-static int copy_user_bh(struct page *to, struct buffer_head *bh,
-                       unsigned blkbits, unsigned long vaddr)
+static int copy_user_bh(struct page *to, struct inode *inode,
+               struct buffer_head *bh, unsigned long vaddr)
 {
-       void __pmem *vfrom;
+       struct blk_dax_ctl dax = {
+               .sector = to_sector(bh, inode),
+               .size = bh->b_size,
+       };
+       struct block_device *bdev = bh->b_bdev;
        void *vto;
 
-       if (dax_get_addr(bh, &vfrom, blkbits) < 0)
-               return -EIO;
+       if (dax_map_atomic(bdev, &dax) < 0)
+               return PTR_ERR(dax.addr);
        vto = kmap_atomic(to);
-       copy_user_page(vto, (void __force *)vfrom, vaddr, to);
+       copy_user_page(vto, (void __force *)dax.addr, vaddr, to);
        kunmap_atomic(vto);
+       dax_unmap_atomic(bdev, &dax);
        return 0;
 }
 
+#define NO_SECTOR -1
+#define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_CACHE_SHIFT))
+
+static int dax_radix_entry(struct address_space *mapping, pgoff_t index,
+               sector_t sector, bool pmd_entry, bool dirty)
+{
+       struct radix_tree_root *page_tree = &mapping->page_tree;
+       pgoff_t pmd_index = DAX_PMD_INDEX(index);
+       int type, error = 0;
+       void *entry;
+
+       WARN_ON_ONCE(pmd_entry && !dirty);
+       __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
+
+       spin_lock_irq(&mapping->tree_lock);
+
+       entry = radix_tree_lookup(page_tree, pmd_index);
+       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD) {
+               index = pmd_index;
+               goto dirty;
+       }
+
+       entry = radix_tree_lookup(page_tree, index);
+       if (entry) {
+               type = RADIX_DAX_TYPE(entry);
+               if (WARN_ON_ONCE(type != RADIX_DAX_PTE &&
+                                       type != RADIX_DAX_PMD)) {
+                       error = -EIO;
+                       goto unlock;
+               }
+
+               if (!pmd_entry || type == RADIX_DAX_PMD)
+                       goto dirty;
+
+               /*
+                * We only insert dirty PMD entries into the radix tree.  This
+                * means we don't need to worry about removing a dirty PTE
+                * entry and inserting a clean PMD entry, thus reducing the
+                * range we would flush with a follow-up fsync/msync call.
+                */
+               radix_tree_delete(&mapping->page_tree, index);
+               mapping->nrexceptional--;
+       }
+
+       if (sector == NO_SECTOR) {
+               /*
+                * This can happen during correct operation if our pfn_mkwrite
+                * fault raced against a hole punch operation.  If this
+                * happens the pte that was hole punched will have been
+                * unmapped and the radix tree entry will have been removed by
+                * the time we are called, but the call will still happen.  We
+                * will return all the way up to wp_pfn_shared(), where the
+                * pte_same() check will fail, eventually causing page fault
+                * to be retried by the CPU.
+                */
+               goto unlock;
+       }
+
+       error = radix_tree_insert(page_tree, index,
+                       RADIX_DAX_ENTRY(sector, pmd_entry));
+       if (error)
+               goto unlock;
+
+       mapping->nrexceptional++;
+ dirty:
+       if (dirty)
+               radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
+ unlock:
+       spin_unlock_irq(&mapping->tree_lock);
+       return error;
+}
+
+static int dax_writeback_one(struct block_device *bdev,
+               struct address_space *mapping, pgoff_t index, void *entry)
+{
+       struct radix_tree_root *page_tree = &mapping->page_tree;
+       int type = RADIX_DAX_TYPE(entry);
+       struct radix_tree_node *node;
+       struct blk_dax_ctl dax;
+       void **slot;
+       int ret = 0;
+
+       spin_lock_irq(&mapping->tree_lock);
+       /*
+        * Regular page slots are stabilized by the page lock even
+        * without the tree itself locked.  These unlocked entries
+        * need verification under the tree lock.
+        */
+       if (!__radix_tree_lookup(page_tree, index, &node, &slot))
+               goto unlock;
+       if (*slot != entry)
+               goto unlock;
+
+       /* another fsync thread may have already written back this entry */
+       if (!radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE))
+               goto unlock;
+
+       if (WARN_ON_ONCE(type != RADIX_DAX_PTE && type != RADIX_DAX_PMD)) {
+               ret = -EIO;
+               goto unlock;
+       }
+
+       dax.sector = RADIX_DAX_SECTOR(entry);
+       dax.size = (type == RADIX_DAX_PMD ? PMD_SIZE : PAGE_SIZE);
+       spin_unlock_irq(&mapping->tree_lock);
+
+       /*
+        * We cannot hold tree_lock while calling dax_map_atomic() because it
+        * eventually calls cond_resched().
+        */
+       ret = dax_map_atomic(bdev, &dax);
+       if (ret < 0)
+               return ret;
+
+       if (WARN_ON_ONCE(ret < dax.size)) {
+               ret = -EIO;
+               goto unmap;
+       }
+
+       wb_cache_pmem(dax.addr, dax.size);
+
+       spin_lock_irq(&mapping->tree_lock);
+       radix_tree_tag_clear(page_tree, index, PAGECACHE_TAG_TOWRITE);
+       spin_unlock_irq(&mapping->tree_lock);
+ unmap:
+       dax_unmap_atomic(bdev, &dax);
+       return ret;
+
+ unlock:
+       spin_unlock_irq(&mapping->tree_lock);
+       return ret;
+}
+
+/*
+ * Flush the mapping to the persistent domain within the byte range of [start,
+ * end]. This is required by data integrity operations to ensure file data is
+ * on persistent storage prior to completion of the operation.
+ */
+int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
+               loff_t end)
+{
+       struct inode *inode = mapping->host;
+       struct block_device *bdev = inode->i_sb->s_bdev;
+       pgoff_t start_index, end_index, pmd_index;
+       pgoff_t indices[PAGEVEC_SIZE];
+       struct pagevec pvec;
+       bool done = false;
+       int i, ret = 0;
+       void *entry;
+
+       if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
+               return -EIO;
+
+       start_index = start >> PAGE_CACHE_SHIFT;
+       end_index = end >> PAGE_CACHE_SHIFT;
+       pmd_index = DAX_PMD_INDEX(start_index);
+
+       rcu_read_lock();
+       entry = radix_tree_lookup(&mapping->page_tree, pmd_index);
+       rcu_read_unlock();
+
+       /* see if the start of our range is covered by a PMD entry */
+       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
+               start_index = pmd_index;
+
+       tag_pages_for_writeback(mapping, start_index, end_index);
+
+       pagevec_init(&pvec, 0);
+       while (!done) {
+               pvec.nr = find_get_entries_tag(mapping, start_index,
+                               PAGECACHE_TAG_TOWRITE, PAGEVEC_SIZE,
+                               pvec.pages, indices);
+
+               if (pvec.nr == 0)
+                       break;
+
+               for (i = 0; i < pvec.nr; i++) {
+                       if (indices[i] > end_index) {
+                               done = true;
+                               break;
+                       }
+
+                       ret = dax_writeback_one(bdev, mapping, indices[i],
+                                       pvec.pages[i]);
+                       if (ret < 0)
+                               return ret;
+               }
+       }
+       wmb_pmem();
+       return 0;
+}
+EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
+
 static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
                        struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-       struct address_space *mapping = inode->i_mapping;
-       sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
        unsigned long vaddr = (unsigned long)vmf->virtual_address;
-       void __pmem *addr;
-       unsigned long pfn;
+       struct address_space *mapping = inode->i_mapping;
+       struct block_device *bdev = bh->b_bdev;
+       struct blk_dax_ctl dax = {
+               .sector = to_sector(bh, inode),
+               .size = bh->b_size,
+       };
        pgoff_t size;
        int error;
 
@@ -315,20 +546,23 @@ static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
                goto out;
        }
 
-       error = bdev_direct_access(bh->b_bdev, sector, &addr, &pfn, bh->b_size);
-       if (error < 0)
-               goto out;
-       if (error < PAGE_SIZE) {
-               error = -EIO;
+       if (dax_map_atomic(bdev, &dax) < 0) {
+               error = PTR_ERR(dax.addr);
                goto out;
        }
 
        if (buffer_unwritten(bh) || buffer_new(bh)) {
-               clear_pmem(addr, PAGE_SIZE);
+               clear_pmem(dax.addr, PAGE_SIZE);
                wmb_pmem();
        }
+       dax_unmap_atomic(bdev, &dax);
+
+       error = dax_radix_entry(mapping, vmf->pgoff, dax.sector, false,
+                       vmf->flags & FAULT_FLAG_WRITE);
+       if (error)
+               goto out;
 
-       error = vm_insert_mixed(vma, vaddr, pfn);
+       error = vm_insert_mixed(vma, vaddr, dax.pfn);
 
  out:
        i_mmap_unlock_read(mapping);
@@ -422,7 +656,7 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        if (vmf->cow_page) {
                struct page *new_page = vmf->cow_page;
                if (buffer_written(&bh))
-                       error = copy_user_bh(new_page, &bh, blkbits, vaddr);
+                       error = copy_user_bh(new_page, inode, &bh, vaddr);
                else
                        clear_user_highpage(new_page, vaddr);
                if (error)
@@ -452,6 +686,7 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                delete_from_page_cache(page);
                unlock_page(page);
                page_cache_release(page);
+               page = NULL;
        }
 
        /*
@@ -523,6 +758,24 @@ EXPORT_SYMBOL_GPL(dax_fault);
  */
 #define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_SHIFT) - 1)
 
+static void __dax_dbg(struct buffer_head *bh, unsigned long address,
+               const char *reason, const char *fn)
+{
+       if (bh) {
+               char bname[BDEVNAME_SIZE];
+               bdevname(bh->b_bdev, bname);
+               pr_debug("%s: %s addr: %lx dev %s state %lx start %lld "
+                       "length %zd fallback: %s\n", fn, current->comm,
+                       address, bname, bh->b_state, (u64)bh->b_blocknr,
+                       bh->b_size, reason);
+       } else {
+               pr_debug("%s: %s addr: %lx fallback: %s\n", fn,
+                       current->comm, address, reason);
+       }
+}
+
+#define dax_pmd_dbg(bh, address, reason)       __dax_dbg(bh, address, reason, "dax_pmd")
+
 int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                pmd_t *pmd, unsigned int flags, get_block_t get_block,
                dax_iodone_t complete_unwritten)
@@ -534,61 +787,83 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
        unsigned blkbits = inode->i_blkbits;
        unsigned long pmd_addr = address & PMD_MASK;
        bool write = flags & FAULT_FLAG_WRITE;
-       long length;
-       void __pmem *kaddr;
+       struct block_device *bdev;
        pgoff_t size, pgoff;
-       sector_t block, sector;
-       unsigned long pfn;
-       int result = 0;
+       sector_t block;
+       int error, result = 0;
+       bool alloc = false;
 
-       /* dax pmd mappings are broken wrt gup and fork */
+       /* dax pmd mappings require pfn_t_devmap() */
        if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
                return VM_FAULT_FALLBACK;
 
        /* Fall back to PTEs if we're going to COW */
-       if (write && !(vma->vm_flags & VM_SHARED))
+       if (write && !(vma->vm_flags & VM_SHARED)) {
+               split_huge_pmd(vma, pmd, address);
+               dax_pmd_dbg(NULL, address, "cow write");
                return VM_FAULT_FALLBACK;
+       }
        /* If the PMD would extend outside the VMA */
-       if (pmd_addr < vma->vm_start)
+       if (pmd_addr < vma->vm_start) {
+               dax_pmd_dbg(NULL, address, "vma start unaligned");
                return VM_FAULT_FALLBACK;
-       if ((pmd_addr + PMD_SIZE) > vma->vm_end)
+       }
+       if ((pmd_addr + PMD_SIZE) > vma->vm_end) {
+               dax_pmd_dbg(NULL, address, "vma end unaligned");
                return VM_FAULT_FALLBACK;
+       }
 
        pgoff = linear_page_index(vma, pmd_addr);
        size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
        if (pgoff >= size)
                return VM_FAULT_SIGBUS;
        /* If the PMD would cover blocks out of the file */
-       if ((pgoff | PG_PMD_COLOUR) >= size)
+       if ((pgoff | PG_PMD_COLOUR) >= size) {
+               dax_pmd_dbg(NULL, address,
+                               "offset + huge page size > file size");
                return VM_FAULT_FALLBACK;
+       }
 
        memset(&bh, 0, sizeof(bh));
+       bh.b_bdev = inode->i_sb->s_bdev;
        block = (sector_t)pgoff << (PAGE_SHIFT - blkbits);
 
        bh.b_size = PMD_SIZE;
-       length = get_block(inode, block, &bh, write);
-       if (length)
+
+       if (get_block(inode, block, &bh, 0) != 0)
                return VM_FAULT_SIGBUS;
-       i_mmap_lock_read(mapping);
+
+       if (!buffer_mapped(&bh) && write) {
+               if (get_block(inode, block, &bh, 1) != 0)
+                       return VM_FAULT_SIGBUS;
+               alloc = true;
+       }
+
+       bdev = bh.b_bdev;
 
        /*
         * If the filesystem isn't willing to tell us the length of a hole,
         * just fall back to PTEs.  Calling get_block 512 times in a loop
         * would be silly.
         */
-       if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE)
-               goto fallback;
+       if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE) {
+               dax_pmd_dbg(&bh, address, "allocated block too small");
+               return VM_FAULT_FALLBACK;
+       }
 
        /*
         * If we allocated new storage, make sure no process has any
         * zero pages covering this hole
         */
-       if (buffer_new(&bh)) {
-               i_mmap_unlock_read(mapping);
-               unmap_mapping_range(mapping, pgoff << PAGE_SHIFT, PMD_SIZE, 0);
-               i_mmap_lock_read(mapping);
+       if (alloc) {
+               loff_t lstart = pgoff << PAGE_SHIFT;
+               loff_t lend = lstart + PMD_SIZE - 1; /* inclusive */
+
+               truncate_pagecache_range(inode, lstart, lend);
        }
 
+       i_mmap_lock_read(mapping);
+
        /*
         * If a truncate happened while we were allocating blocks, we may
         * leave blocks allocated to the file that are beyond EOF.  We can't
@@ -600,57 +875,108 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                result = VM_FAULT_SIGBUS;
                goto out;
        }
-       if ((pgoff | PG_PMD_COLOUR) >= size)
+       if ((pgoff | PG_PMD_COLOUR) >= size) {
+               dax_pmd_dbg(&bh, address,
+                               "offset + huge page size > file size");
                goto fallback;
+       }
 
        if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
                spinlock_t *ptl;
                pmd_t entry;
                struct page *zero_page = get_huge_zero_page();
 
-               if (unlikely(!zero_page))
+               if (unlikely(!zero_page)) {
+                       dax_pmd_dbg(&bh, address, "no zero page");
                        goto fallback;
+               }
 
                ptl = pmd_lock(vma->vm_mm, pmd);
                if (!pmd_none(*pmd)) {
                        spin_unlock(ptl);
+                       dax_pmd_dbg(&bh, address, "pmd already present");
                        goto fallback;
                }
 
+               dev_dbg(part_to_dev(bdev->bd_part),
+                               "%s: %s addr: %lx pfn: <zero> sect: %llx\n",
+                               __func__, current->comm, address,
+                               (unsigned long long) to_sector(&bh, inode));
+
                entry = mk_pmd(zero_page, vma->vm_page_prot);
                entry = pmd_mkhuge(entry);
                set_pmd_at(vma->vm_mm, pmd_addr, pmd, entry);
                result = VM_FAULT_NOPAGE;
                spin_unlock(ptl);
        } else {
-               sector = bh.b_blocknr << (blkbits - 9);
-               length = bdev_direct_access(bh.b_bdev, sector, &kaddr, &pfn,
-                                               bh.b_size);
+               struct blk_dax_ctl dax = {
+                       .sector = to_sector(&bh, inode),
+                       .size = PMD_SIZE,
+               };
+               long length = dax_map_atomic(bdev, &dax);
+
                if (length < 0) {
                        result = VM_FAULT_SIGBUS;
                        goto out;
                }
-               if ((length < PMD_SIZE) || (pfn & PG_PMD_COLOUR))
+               if (length < PMD_SIZE) {
+                       dax_pmd_dbg(&bh, address, "dax-length too small");
+                       dax_unmap_atomic(bdev, &dax);
                        goto fallback;
+               }
+               if (pfn_t_to_pfn(dax.pfn) & PG_PMD_COLOUR) {
+                       dax_pmd_dbg(&bh, address, "pfn unaligned");
+                       dax_unmap_atomic(bdev, &dax);
+                       goto fallback;
+               }
 
-               /*
-                * TODO: teach vmf_insert_pfn_pmd() to support
-                * 'pte_special' for pmds
-                */
-               if (pfn_valid(pfn))
+               if (!pfn_t_devmap(dax.pfn)) {
+                       dax_unmap_atomic(bdev, &dax);
+                       dax_pmd_dbg(&bh, address, "pfn not in memmap");
                        goto fallback;
+               }
 
                if (buffer_unwritten(&bh) || buffer_new(&bh)) {
-                       int i;
-                       for (i = 0; i < PTRS_PER_PMD; i++)
-                               clear_pmem(kaddr + i * PAGE_SIZE, PAGE_SIZE);
+                       clear_pmem(dax.addr, PMD_SIZE);
                        wmb_pmem();
                        count_vm_event(PGMAJFAULT);
                        mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
                        result |= VM_FAULT_MAJOR;
                }
+               dax_unmap_atomic(bdev, &dax);
+
+               /*
+                * For PTE faults we insert a radix tree entry for reads, and
+                * leave it clean.  Then on the first write we dirty the radix
+                * tree entry via the dax_pfn_mkwrite() path.  This sequence
+                * allows the dax_pfn_mkwrite() call to be simpler and avoid a
+                * call into get_block() to translate the pgoff to a sector in
+                * order to be able to create a new radix tree entry.
+                *
+                * The PMD path doesn't have an equivalent to
+                * dax_pfn_mkwrite(), though, so for a read followed by a
+                * write we traverse all the way through __dax_pmd_fault()
+                * twice.  This means we can just skip inserting a radix tree
+                * entry completely on the initial read and just wait until
+                * the write to insert a dirty entry.
+                */
+               if (write) {
+                       error = dax_radix_entry(mapping, pgoff, dax.sector,
+                                       true, true);
+                       if (error) {
+                               dax_pmd_dbg(&bh, address,
+                                               "PMD radix insertion failed");
+                               goto fallback;
+                       }
+               }
 
-               result |= vmf_insert_pfn_pmd(vma, address, pmd, pfn, write);
+               dev_dbg(part_to_dev(bdev->bd_part),
+                               "%s: %s addr: %lx pfn: %lx sect: %llx\n",
+                               __func__, current->comm, address,
+                               pfn_t_to_pfn(dax.pfn),
+                               (unsigned long long) dax.sector);
+               result |= vmf_insert_pfn_pmd(vma, address, pmd,
+                               dax.pfn, write);
        }
 
  out:
@@ -702,15 +1028,20 @@ EXPORT_SYMBOL_GPL(dax_pmd_fault);
  * dax_pfn_mkwrite - handle first write to DAX page
  * @vma: The virtual memory area where the fault occurred
  * @vmf: The description of the fault
- *
  */
 int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-       struct super_block *sb = file_inode(vma->vm_file)->i_sb;
+       struct file *file = vma->vm_file;
 
-       sb_start_pagefault(sb);
-       file_update_time(vma->vm_file);
-       sb_end_pagefault(sb);
+       /*
+        * We pass NO_SECTOR to dax_radix_entry() because we expect that a
+        * RADIX_DAX_PTE entry already exists in the radix tree from a
+        * previous call to __dax_fault().  We just want to look up that PTE
+        * entry using vmf->pgoff and make sure the dirty tag is set.  This
+        * saves us from having to make a call to get_block() here to look
+        * up the sector.
+        */
+       dax_radix_entry(file->f_mapping, vmf->pgoff, NO_SECTOR, false, true);
        return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
@@ -752,12 +1083,17 @@ int dax_zero_page_range(struct inode *inode, loff_t from, unsigned length,
        if (err < 0)
                return err;
        if (buffer_written(&bh)) {
-               void __pmem *addr;
-               err = dax_get_addr(&bh, &addr, inode->i_blkbits);
-               if (err < 0)
-                       return err;
-               clear_pmem(addr + offset, length);
+               struct block_device *bdev = bh.b_bdev;
+               struct blk_dax_ctl dax = {
+                       .sector = to_sector(&bh, inode),
+                       .size = PAGE_CACHE_SIZE,
+               };
+
+               if (dax_map_atomic(bdev, &dax) < 0)
+                       return PTR_ERR(dax.addr);
+               clear_pmem(dax.addr + offset, length);
                wmb_pmem();
+               dax_unmap_atomic(bdev, &dax);
        }
 
        return 0;