From c7753208a94c73d5beb1e4bd843081d6dc7d4678 Mon Sep 17 00:00:00 2001 From: Tom Lendacky Date: Mon, 17 Jul 2017 16:10:21 -0500 Subject: x86, swiotlb: Add memory encryption support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since DMA addresses will effectively look like 48-bit addresses when the memory encryption mask is set, SWIOTLB is needed if the DMA mask of the device performing the DMA does not support 48-bits. SWIOTLB will be initialized to create decrypted bounce buffers for use by these devices. Signed-off-by: Tom Lendacky Reviewed-by: Thomas Gleixner Cc: Alexander Potapenko Cc: Andrey Ryabinin Cc: Andy Lutomirski Cc: Arnd Bergmann Cc: Borislav Petkov Cc: Brijesh Singh Cc: Dave Young Cc: Dmitry Vyukov Cc: Jonathan Corbet Cc: Konrad Rzeszutek Wilk Cc: Larry Woodman Cc: Linus Torvalds Cc: Matt Fleming Cc: Michael S. Tsirkin Cc: Paolo Bonzini Cc: Peter Zijlstra Cc: Radim Krčmář Cc: Rik van Riel Cc: Toshimitsu Kani Cc: kasan-dev@googlegroups.com Cc: kvm@vger.kernel.org Cc: linux-arch@vger.kernel.org Cc: linux-doc@vger.kernel.org Cc: linux-efi@vger.kernel.org Cc: linux-mm@kvack.org Link: http://lkml.kernel.org/r/aa2d29b78ae7d508db8881e46a3215231b9327a7.1500319216.git.thomas.lendacky@amd.com Signed-off-by: Ingo Molnar --- lib/swiotlb.c | 54 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 8 deletions(-) (limited to 'lib') diff --git a/lib/swiotlb.c b/lib/swiotlb.c index a8d74a733a38..04ac91acf193 100644 --- a/lib/swiotlb.c +++ b/lib/swiotlb.c @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -155,6 +156,15 @@ unsigned long swiotlb_size_or_default(void) return size ? size : (IO_TLB_DEFAULT_SIZE); } +void __weak swiotlb_set_mem_attributes(void *vaddr, unsigned long size) { } + +/* For swiotlb, clear memory encryption mask from dma addresses */ +static dma_addr_t swiotlb_phys_to_dma(struct device *hwdev, + phys_addr_t address) +{ + return __sme_clr(phys_to_dma(hwdev, address)); +} + /* Note that this doesn't work with highmem page */ static dma_addr_t swiotlb_virt_to_bus(struct device *hwdev, volatile void *address) @@ -183,6 +193,31 @@ void swiotlb_print_info(void) bytes >> 20, vstart, vend - 1); } +/* + * Early SWIOTLB allocation may be too early to allow an architecture to + * perform the desired operations. This function allows the architecture to + * call SWIOTLB when the operations are possible. It needs to be called + * before the SWIOTLB memory is used. + */ +void __init swiotlb_update_mem_attributes(void) +{ + void *vaddr; + unsigned long bytes; + + if (no_iotlb_memory || late_alloc) + return; + + vaddr = phys_to_virt(io_tlb_start); + bytes = PAGE_ALIGN(io_tlb_nslabs << IO_TLB_SHIFT); + swiotlb_set_mem_attributes(vaddr, bytes); + memset(vaddr, 0, bytes); + + vaddr = phys_to_virt(io_tlb_overflow_buffer); + bytes = PAGE_ALIGN(io_tlb_overflow); + swiotlb_set_mem_attributes(vaddr, bytes); + memset(vaddr, 0, bytes); +} + int __init swiotlb_init_with_tbl(char *tlb, unsigned long nslabs, int verbose) { void *v_overflow_buffer; @@ -320,6 +355,7 @@ swiotlb_late_init_with_tbl(char *tlb, unsigned long nslabs) io_tlb_start = virt_to_phys(tlb); io_tlb_end = io_tlb_start + bytes; + swiotlb_set_mem_attributes(tlb, bytes); memset(tlb, 0, bytes); /* @@ -330,6 +366,8 @@ swiotlb_late_init_with_tbl(char *tlb, unsigned long nslabs) if (!v_overflow_buffer) goto cleanup2; + swiotlb_set_mem_attributes(v_overflow_buffer, io_tlb_overflow); + memset(v_overflow_buffer, 0, io_tlb_overflow); io_tlb_overflow_buffer = virt_to_phys(v_overflow_buffer); /* @@ -581,7 +619,7 @@ map_single(struct device *hwdev, phys_addr_t phys, size_t size, return SWIOTLB_MAP_ERROR; } - start_dma_addr = phys_to_dma(hwdev, io_tlb_start); + start_dma_addr = swiotlb_phys_to_dma(hwdev, io_tlb_start); return swiotlb_tbl_map_single(hwdev, start_dma_addr, phys, size, dir, attrs); } @@ -702,7 +740,7 @@ swiotlb_alloc_coherent(struct device *hwdev, size_t size, goto err_warn; ret = phys_to_virt(paddr); - dev_addr = phys_to_dma(hwdev, paddr); + dev_addr = swiotlb_phys_to_dma(hwdev, paddr); /* Confirm address can be DMA'd by device */ if (dev_addr + size - 1 > dma_mask) { @@ -812,10 +850,10 @@ dma_addr_t swiotlb_map_page(struct device *dev, struct page *page, map = map_single(dev, phys, size, dir, attrs); if (map == SWIOTLB_MAP_ERROR) { swiotlb_full(dev, size, dir, 1); - return phys_to_dma(dev, io_tlb_overflow_buffer); + return swiotlb_phys_to_dma(dev, io_tlb_overflow_buffer); } - dev_addr = phys_to_dma(dev, map); + dev_addr = swiotlb_phys_to_dma(dev, map); /* Ensure that the address returned is DMA'ble */ if (dma_capable(dev, dev_addr, size)) @@ -824,7 +862,7 @@ dma_addr_t swiotlb_map_page(struct device *dev, struct page *page, attrs |= DMA_ATTR_SKIP_CPU_SYNC; swiotlb_tbl_unmap_single(dev, map, size, dir, attrs); - return phys_to_dma(dev, io_tlb_overflow_buffer); + return swiotlb_phys_to_dma(dev, io_tlb_overflow_buffer); } EXPORT_SYMBOL_GPL(swiotlb_map_page); @@ -958,7 +996,7 @@ swiotlb_map_sg_attrs(struct device *hwdev, struct scatterlist *sgl, int nelems, sg_dma_len(sgl) = 0; return 0; } - sg->dma_address = phys_to_dma(hwdev, map); + sg->dma_address = swiotlb_phys_to_dma(hwdev, map); } else sg->dma_address = dev_addr; sg_dma_len(sg) = sg->length; @@ -1026,7 +1064,7 @@ EXPORT_SYMBOL(swiotlb_sync_sg_for_device); int swiotlb_dma_mapping_error(struct device *hwdev, dma_addr_t dma_addr) { - return (dma_addr == phys_to_dma(hwdev, io_tlb_overflow_buffer)); + return (dma_addr == swiotlb_phys_to_dma(hwdev, io_tlb_overflow_buffer)); } EXPORT_SYMBOL(swiotlb_dma_mapping_error); @@ -1039,6 +1077,6 @@ EXPORT_SYMBOL(swiotlb_dma_mapping_error); int swiotlb_dma_supported(struct device *hwdev, u64 mask) { - return phys_to_dma(hwdev, io_tlb_end - 1) <= mask; + return swiotlb_phys_to_dma(hwdev, io_tlb_end - 1) <= mask; } EXPORT_SYMBOL(swiotlb_dma_supported); -- cgit v1.2.3 From 648babb7078c6310d2af5b8aa01f086030916968 Mon Sep 17 00:00:00 2001 From: Tom Lendacky Date: Mon, 17 Jul 2017 16:10:22 -0500 Subject: swiotlb: Add warnings for use of bounce buffers with SME MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add warnings to let the user know when bounce buffers are being used for DMA when SME is active. Since the bounce buffers are not in encrypted memory, these notifications are to allow the user to determine some appropriate action - if necessary. Actions can range from utilizing an IOMMU, replacing the device with another device that can support 64-bit DMA, ignoring the message if the device isn't used much, etc. Signed-off-by: Tom Lendacky Reviewed-by: Thomas Gleixner Cc: Alexander Potapenko Cc: Andrey Ryabinin Cc: Andy Lutomirski Cc: Arnd Bergmann Cc: Borislav Petkov Cc: Brijesh Singh Cc: Dave Young Cc: Dmitry Vyukov Cc: Jonathan Corbet Cc: Konrad Rzeszutek Wilk Cc: Larry Woodman Cc: Linus Torvalds Cc: Matt Fleming Cc: Michael S. Tsirkin Cc: Paolo Bonzini Cc: Peter Zijlstra Cc: Radim Krčmář Cc: Rik van Riel Cc: Toshimitsu Kani Cc: kasan-dev@googlegroups.com Cc: kvm@vger.kernel.org Cc: linux-arch@vger.kernel.org Cc: linux-doc@vger.kernel.org Cc: linux-efi@vger.kernel.org Cc: linux-mm@kvack.org Link: http://lkml.kernel.org/r/d112564053c3f2e86ca634a8d4fa4abc0eb53a6a.1500319216.git.thomas.lendacky@amd.com Signed-off-by: Ingo Molnar --- include/linux/dma-mapping.h | 13 +++++++++++++ lib/swiotlb.c | 3 +++ 2 files changed, 16 insertions(+) (limited to 'lib') diff --git a/include/linux/dma-mapping.h b/include/linux/dma-mapping.h index 843ab866e0f4..fce2369ecf82 100644 --- a/include/linux/dma-mapping.h +++ b/include/linux/dma-mapping.h @@ -10,6 +10,7 @@ #include #include #include +#include /** * List of possible attributes associated with a DMA mapping. The semantics @@ -548,6 +549,12 @@ static inline int dma_mapping_error(struct device *dev, dma_addr_t dma_addr) return 0; } +static inline void dma_check_mask(struct device *dev, u64 mask) +{ + if (sme_active() && (mask < (((u64)sme_get_me_mask() << 1) - 1))) + dev_warn(dev, "SME is active, device will require DMA bounce buffers\n"); +} + static inline int dma_supported(struct device *dev, u64 mask) { const struct dma_map_ops *ops = get_dma_ops(dev); @@ -564,6 +571,9 @@ static inline int dma_set_mask(struct device *dev, u64 mask) { if (!dev->dma_mask || !dma_supported(dev, mask)) return -EIO; + + dma_check_mask(dev, mask); + *dev->dma_mask = mask; return 0; } @@ -583,6 +593,9 @@ static inline int dma_set_coherent_mask(struct device *dev, u64 mask) { if (!dma_supported(dev, mask)) return -EIO; + + dma_check_mask(dev, mask); + dev->coherent_dma_mask = mask; return 0; } diff --git a/lib/swiotlb.c b/lib/swiotlb.c index 04ac91acf193..8c6c83ef57a4 100644 --- a/lib/swiotlb.c +++ b/lib/swiotlb.c @@ -507,6 +507,9 @@ phys_addr_t swiotlb_tbl_map_single(struct device *hwdev, if (no_iotlb_memory) panic("Can not allocate SWIOTLB buffer earlier and can't now provide you with the DMA bounce buffer"); + if (sme_active()) + pr_warn_once("SME is active and system is using DMA bounce buffers\n"); + mask = dma_get_seg_boundary(hwdev); tbl_dma_addr &= mask; -- cgit v1.2.3 From 1455cf8dbfd06aa7651dcfccbadb7a093944ca65 Mon Sep 17 00:00:00 2001 From: Dmitry Torokhov Date: Wed, 19 Jul 2017 17:24:30 -0700 Subject: driver core: emit uevents when device is bound to a driver There are certain touch controllers that may come up in either normal (application) or boot mode, depending on whether firmware/configuration is corrupted when they are powered on. In boot mode the kernel does not create input device instance (because it does not necessarily know the characteristics of the input device in question). Another number of controllers does not store firmware in a non-volatile memory, and they similarly need to have firmware loaded before input device instance is created. There are also other types of devices with similar behavior. There is a desire to be able to trigger firmware loading via udev, but it has to happen only when driver is bound to a physical device (i2c or spi). These udev actions can not use ADD events, as those happen too early, so we are introducing BIND and UNBIND events that are emitted at the right moment. Also, many drivers create additional driver-specific device attributes when binding to the device, to provide userspace with additional controls. The new events allow userspace to adjust these driver-specific attributes without worrying that they are not there yet. Signed-off-by: Dmitry Torokhov Signed-off-by: Greg Kroah-Hartman --- drivers/base/dd.c | 4 ++++ include/linux/kobject.h | 2 ++ lib/kobject_uevent.c | 2 ++ 3 files changed, 8 insertions(+) (limited to 'lib') diff --git a/drivers/base/dd.c b/drivers/base/dd.c index 4882f06d12df..c17fefc77345 100644 --- a/drivers/base/dd.c +++ b/drivers/base/dd.c @@ -259,6 +259,8 @@ static void driver_bound(struct device *dev) if (dev->bus) blocking_notifier_call_chain(&dev->bus->p->bus_notifier, BUS_NOTIFY_BOUND_DRIVER, dev); + + kobject_uevent(&dev->kobj, KOBJ_BIND); } static int driver_sysfs_add(struct device *dev) @@ -848,6 +850,8 @@ static void __device_release_driver(struct device *dev, struct device *parent) blocking_notifier_call_chain(&dev->bus->p->bus_notifier, BUS_NOTIFY_UNBOUND_DRIVER, dev); + + kobject_uevent(&dev->kobj, KOBJ_UNBIND); } } diff --git a/include/linux/kobject.h b/include/linux/kobject.h index ca85cb80e99a..12f5ddccb97c 100644 --- a/include/linux/kobject.h +++ b/include/linux/kobject.h @@ -57,6 +57,8 @@ enum kobject_action { KOBJ_MOVE, KOBJ_ONLINE, KOBJ_OFFLINE, + KOBJ_BIND, + KOBJ_UNBIND, KOBJ_MAX }; diff --git a/lib/kobject_uevent.c b/lib/kobject_uevent.c index 9a2b811966eb..4682e8545b5c 100644 --- a/lib/kobject_uevent.c +++ b/lib/kobject_uevent.c @@ -50,6 +50,8 @@ static const char *kobject_actions[] = { [KOBJ_MOVE] = "move", [KOBJ_ONLINE] = "online", [KOBJ_OFFLINE] = "offline", + [KOBJ_BIND] = "bind", + [KOBJ_UNBIND] = "unbind", }; /** -- cgit v1.2.3 From ee9f8fce99640811b2b8e79d0d1dbe8bab69ba67 Mon Sep 17 00:00:00 2001 From: Josh Poimboeuf Date: Mon, 24 Jul 2017 18:36:57 -0500 Subject: x86/unwind: Add the ORC unwinder Add the new ORC unwinder which is enabled by CONFIG_ORC_UNWINDER=y. It plugs into the existing x86 unwinder framework. It relies on objtool to generate the needed .orc_unwind and .orc_unwind_ip sections. For more details on why ORC is used instead of DWARF, see Documentation/x86/orc-unwinder.txt - but the short version is that it's a simplified, fundamentally more robust debugninfo data structure, which also allows up to two orders of magnitude faster lookups than the DWARF unwinder - which matters to profiling workloads like perf. Thanks to Andy Lutomirski for the performance improvement ideas: splitting the ORC unwind table into two parallel arrays and creating a fast lookup table to search a subset of the unwind table. Signed-off-by: Josh Poimboeuf Cc: Andy Lutomirski Cc: Borislav Petkov Cc: Brian Gerst Cc: Denys Vlasenko Cc: H. Peter Anvin Cc: Jiri Slaby Cc: Linus Torvalds Cc: Mike Galbraith Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: live-patching@vger.kernel.org Link: http://lkml.kernel.org/r/0a6cbfb40f8da99b7a45a1a8302dc6aef16ec812.1500938583.git.jpoimboe@redhat.com [ Extended the changelog. ] Signed-off-by: Ingo Molnar --- Documentation/x86/orc-unwinder.txt | 179 ++++++++++++ arch/um/include/asm/unwind.h | 8 + arch/x86/Kconfig | 1 + arch/x86/Kconfig.debug | 25 ++ arch/x86/include/asm/module.h | 9 + arch/x86/include/asm/orc_lookup.h | 46 +++ arch/x86/include/asm/orc_types.h | 2 +- arch/x86/include/asm/unwind.h | 76 +++-- arch/x86/kernel/Makefile | 8 +- arch/x86/kernel/module.c | 11 +- arch/x86/kernel/setup.c | 3 + arch/x86/kernel/unwind_frame.c | 39 +-- arch/x86/kernel/unwind_guess.c | 5 + arch/x86/kernel/unwind_orc.c | 582 +++++++++++++++++++++++++++++++++++++ arch/x86/kernel/vmlinux.lds.S | 3 + include/asm-generic/vmlinux.lds.h | 27 +- lib/Kconfig.debug | 3 + scripts/Makefile.build | 14 +- 18 files changed, 977 insertions(+), 64 deletions(-) create mode 100644 Documentation/x86/orc-unwinder.txt create mode 100644 arch/um/include/asm/unwind.h create mode 100644 arch/x86/include/asm/orc_lookup.h create mode 100644 arch/x86/kernel/unwind_orc.c (limited to 'lib') diff --git a/Documentation/x86/orc-unwinder.txt b/Documentation/x86/orc-unwinder.txt new file mode 100644 index 000000000000..af0c9a4c65a6 --- /dev/null +++ b/Documentation/x86/orc-unwinder.txt @@ -0,0 +1,179 @@ +ORC unwinder +============ + +Overview +-------- + +The kernel CONFIG_ORC_UNWINDER option enables the ORC unwinder, which is +similar in concept to a DWARF unwinder. The difference is that the +format of the ORC data is much simpler than DWARF, which in turn allows +the ORC unwinder to be much simpler and faster. + +The ORC data consists of unwind tables which are generated by objtool. +They contain out-of-band data which is used by the in-kernel ORC +unwinder. Objtool generates the ORC data by first doing compile-time +stack metadata validation (CONFIG_STACK_VALIDATION). After analyzing +all the code paths of a .o file, it determines information about the +stack state at each instruction address in the file and outputs that +information to the .orc_unwind and .orc_unwind_ip sections. + +The per-object ORC sections are combined at link time and are sorted and +post-processed at boot time. The unwinder uses the resulting data to +correlate instruction addresses with their stack states at run time. + + +ORC vs frame pointers +--------------------- + +With frame pointers enabled, GCC adds instrumentation code to every +function in the kernel. The kernel's .text size increases by about +3.2%, resulting in a broad kernel-wide slowdown. Measurements by Mel +Gorman [1] have shown a slowdown of 5-10% for some workloads. + +In contrast, the ORC unwinder has no effect on text size or runtime +performance, because the debuginfo is out of band. So if you disable +frame pointers and enable the ORC unwinder, you get a nice performance +improvement across the board, and still have reliable stack traces. + +Ingo Molnar says: + + "Note that it's not just a performance improvement, but also an + instruction cache locality improvement: 3.2% .text savings almost + directly transform into a similarly sized reduction in cache + footprint. That can transform to even higher speedups for workloads + whose cache locality is borderline." + +Another benefit of ORC compared to frame pointers is that it can +reliably unwind across interrupts and exceptions. Frame pointer based +unwinds can sometimes skip the caller of the interrupted function, if it +was a leaf function or if the interrupt hit before the frame pointer was +saved. + +The main disadvantage of the ORC unwinder compared to frame pointers is +that it needs more memory to store the ORC unwind tables: roughly 2-4MB +depending on the kernel config. + + +ORC vs DWARF +------------ + +ORC debuginfo's advantage over DWARF itself is that it's much simpler. +It gets rid of the complex DWARF CFI state machine and also gets rid of +the tracking of unnecessary registers. This allows the unwinder to be +much simpler, meaning fewer bugs, which is especially important for +mission critical oops code. + +The simpler debuginfo format also enables the unwinder to be much faster +than DWARF, which is important for perf and lockdep. In a basic +performance test by Jiri Slaby [2], the ORC unwinder was about 20x +faster than an out-of-tree DWARF unwinder. (Note: That measurement was +taken before some performance tweaks were added, which doubled +performance, so the speedup over DWARF may be closer to 40x.) + +The ORC data format does have a few downsides compared to DWARF. ORC +unwind tables take up ~50% more RAM (+1.3MB on an x86 defconfig kernel) +than DWARF-based eh_frame tables. + +Another potential downside is that, as GCC evolves, it's conceivable +that the ORC data may end up being *too* simple to describe the state of +the stack for certain optimizations. But IMO this is unlikely because +GCC saves the frame pointer for any unusual stack adjustments it does, +so I suspect we'll really only ever need to keep track of the stack +pointer and the frame pointer between call frames. But even if we do +end up having to track all the registers DWARF tracks, at least we will +still be able to control the format, e.g. no complex state machines. + + +ORC unwind table generation +--------------------------- + +The ORC data is generated by objtool. With the existing compile-time +stack metadata validation feature, objtool already follows all code +paths, and so it already has all the information it needs to be able to +generate ORC data from scratch. So it's an easy step to go from stack +validation to ORC data generation. + +It should be possible to instead generate the ORC data with a simple +tool which converts DWARF to ORC data. However, such a solution would +be incomplete due to the kernel's extensive use of asm, inline asm, and +special sections like exception tables. + +That could be rectified by manually annotating those special code paths +using GNU assembler .cfi annotations in .S files, and homegrown +annotations for inline asm in .c files. But asm annotations were tried +in the past and were found to be unmaintainable. They were often +incorrect/incomplete and made the code harder to read and keep updated. +And based on looking at glibc code, annotating inline asm in .c files +might be even worse. + +Objtool still needs a few annotations, but only in code which does +unusual things to the stack like entry code. And even then, far fewer +annotations are needed than what DWARF would need, so they're much more +maintainable than DWARF CFI annotations. + +So the advantages of using objtool to generate ORC data are that it +gives more accurate debuginfo, with very few annotations. It also +insulates the kernel from toolchain bugs which can be very painful to +deal with in the kernel since we often have to workaround issues in +older versions of the toolchain for years. + +The downside is that the unwinder now becomes dependent on objtool's +ability to reverse engineer GCC code flow. If GCC optimizations become +too complicated for objtool to follow, the ORC data generation might +stop working or become incomplete. (It's worth noting that livepatch +already has such a dependency on objtool's ability to follow GCC code +flow.) + +If newer versions of GCC come up with some optimizations which break +objtool, we may need to revisit the current implementation. Some +possible solutions would be asking GCC to make the optimizations more +palatable, or having objtool use DWARF as an additional input, or +creating a GCC plugin to assist objtool with its analysis. But for now, +objtool follows GCC code quite well. + + +Unwinder implementation details +------------------------------- + +Objtool generates the ORC data by integrating with the compile-time +stack metadata validation feature, which is described in detail in +tools/objtool/Documentation/stack-validation.txt. After analyzing all +the code paths of a .o file, it creates an array of orc_entry structs, +and a parallel array of instruction addresses associated with those +structs, and writes them to the .orc_unwind and .orc_unwind_ip sections +respectively. + +The ORC data is split into the two arrays for performance reasons, to +make the searchable part of the data (.orc_unwind_ip) more compact. The +arrays are sorted in parallel at boot time. + +Performance is further improved by the use of a fast lookup table which +is created at runtime. The fast lookup table associates a given address +with a range of indices for the .orc_unwind table, so that only a small +subset of the table needs to be searched. + + +Etymology +--------- + +Orcs, fearsome creatures of medieval folklore, are the Dwarves' natural +enemies. Similarly, the ORC unwinder was created in opposition to the +complexity and slowness of DWARF. + +"Although Orcs rarely consider multiple solutions to a problem, they do +excel at getting things done because they are creatures of action, not +thought." [3] Similarly, unlike the esoteric DWARF unwinder, the +veracious ORC unwinder wastes no time or siloconic effort decoding +variable-length zero-extended unsigned-integer byte-coded +state-machine-based debug information entries. + +Similar to how Orcs frequently unravel the well-intentioned plans of +their adversaries, the ORC unwinder frequently unravels stacks with +brutal, unyielding efficiency. + +ORC stands for Oops Rewind Capability. + + +[1] https://lkml.kernel.org/r/20170602104048.jkkzssljsompjdwy@suse.de +[2] https://lkml.kernel.org/r/d2ca5435-6386-29b8-db87-7f227c2b713a@suse.cz +[3] http://dustin.wikidot.com/half-orcs-and-orcs diff --git a/arch/um/include/asm/unwind.h b/arch/um/include/asm/unwind.h new file mode 100644 index 000000000000..7ffa5437b761 --- /dev/null +++ b/arch/um/include/asm/unwind.h @@ -0,0 +1,8 @@ +#ifndef _ASM_UML_UNWIND_H +#define _ASM_UML_UNWIND_H + +static inline void +unwind_module_init(struct module *mod, void *orc_ip, size_t orc_ip_size, + void *orc, size_t orc_size) {} + +#endif /* _ASM_UML_UNWIND_H */ diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 781521b7cf9e..7ccf26a427cb 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -157,6 +157,7 @@ config X86 select HAVE_MEMBLOCK select HAVE_MEMBLOCK_NODE_MAP select HAVE_MIXED_BREAKPOINTS_REGS + select HAVE_MOD_ARCH_SPECIFIC select HAVE_NMI select HAVE_OPROFILE select HAVE_OPTPROBES diff --git a/arch/x86/Kconfig.debug b/arch/x86/Kconfig.debug index 353ed09f5fba..dc10ec6ead6e 100644 --- a/arch/x86/Kconfig.debug +++ b/arch/x86/Kconfig.debug @@ -355,4 +355,29 @@ config PUNIT_ATOM_DEBUG The current power state can be read from /sys/kernel/debug/punit_atom/dev_power_state +config ORC_UNWINDER + bool "ORC unwinder" + depends on X86_64 + select STACK_VALIDATION + ---help--- + This option enables the ORC (Oops Rewind Capability) unwinder for + unwinding kernel stack traces. It uses a custom data format which is + a simplified version of the DWARF Call Frame Information standard. + + This unwinder is more accurate across interrupt entry frames than the + frame pointer unwinder. It can also enable a 5-10% performance + improvement across the entire kernel if CONFIG_FRAME_POINTER is + disabled. + + Enabling this option will increase the kernel's runtime memory usage + by roughly 2-4MB, depending on your kernel config. + +config FRAME_POINTER_UNWINDER + def_bool y + depends on !ORC_UNWINDER && FRAME_POINTER + +config GUESS_UNWINDER + def_bool y + depends on !ORC_UNWINDER && !FRAME_POINTER + endmenu diff --git a/arch/x86/include/asm/module.h b/arch/x86/include/asm/module.h index e3b7819caeef..9eb7c718aaf8 100644 --- a/arch/x86/include/asm/module.h +++ b/arch/x86/include/asm/module.h @@ -2,6 +2,15 @@ #define _ASM_X86_MODULE_H #include +#include + +struct mod_arch_specific { +#ifdef CONFIG_ORC_UNWINDER + unsigned int num_orcs; + int *orc_unwind_ip; + struct orc_entry *orc_unwind; +#endif +}; #ifdef CONFIG_X86_64 /* X86_64 does not define MODULE_PROC_FAMILY */ diff --git a/arch/x86/include/asm/orc_lookup.h b/arch/x86/include/asm/orc_lookup.h new file mode 100644 index 000000000000..91c8d868424d --- /dev/null +++ b/arch/x86/include/asm/orc_lookup.h @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2017 Josh Poimboeuf + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, see . + */ +#ifndef _ORC_LOOKUP_H +#define _ORC_LOOKUP_H + +/* + * This is a lookup table for speeding up access to the .orc_unwind table. + * Given an input address offset, the corresponding lookup table entry + * specifies a subset of the .orc_unwind table to search. + * + * Each block represents the end of the previous range and the start of the + * next range. An extra block is added to give the last range an end. + * + * The block size should be a power of 2 to avoid a costly 'div' instruction. + * + * A block size of 256 was chosen because it roughly doubles unwinder + * performance while only adding ~5% to the ORC data footprint. + */ +#define LOOKUP_BLOCK_ORDER 8 +#define LOOKUP_BLOCK_SIZE (1 << LOOKUP_BLOCK_ORDER) + +#ifndef LINKER_SCRIPT + +extern unsigned int orc_lookup[]; +extern unsigned int orc_lookup_end[]; + +#define LOOKUP_START_IP (unsigned long)_stext +#define LOOKUP_STOP_IP (unsigned long)_etext + +#endif /* LINKER_SCRIPT */ + +#endif /* _ORC_LOOKUP_H */ diff --git a/arch/x86/include/asm/orc_types.h b/arch/x86/include/asm/orc_types.h index 7dc777a6cb40..9c9dc579bd7d 100644 --- a/arch/x86/include/asm/orc_types.h +++ b/arch/x86/include/asm/orc_types.h @@ -88,7 +88,7 @@ struct orc_entry { unsigned sp_reg:4; unsigned bp_reg:4; unsigned type:2; -}; +} __packed; /* * This struct is used by asm and inline asm code to manually annotate the diff --git a/arch/x86/include/asm/unwind.h b/arch/x86/include/asm/unwind.h index e6676495b125..25b8d31a007d 100644 --- a/arch/x86/include/asm/unwind.h +++ b/arch/x86/include/asm/unwind.h @@ -12,11 +12,14 @@ struct unwind_state { struct task_struct *task; int graph_idx; bool error; -#ifdef CONFIG_FRAME_POINTER +#if defined(CONFIG_ORC_UNWINDER) + bool signal, full_regs; + unsigned long sp, bp, ip; + struct pt_regs *regs; +#elif defined(CONFIG_FRAME_POINTER) bool got_irq; - unsigned long *bp, *orig_sp; + unsigned long *bp, *orig_sp, ip; struct pt_regs *regs; - unsigned long ip; #else unsigned long *sp; #endif @@ -24,41 +27,30 @@ struct unwind_state { void __unwind_start(struct unwind_state *state, struct task_struct *task, struct pt_regs *regs, unsigned long *first_frame); - bool unwind_next_frame(struct unwind_state *state); - unsigned long unwind_get_return_address(struct unwind_state *state); +unsigned long *unwind_get_return_address_ptr(struct unwind_state *state); static inline bool unwind_done(struct unwind_state *state) { return state->stack_info.type == STACK_TYPE_UNKNOWN; } -static inline -void unwind_start(struct unwind_state *state, struct task_struct *task, - struct pt_regs *regs, unsigned long *first_frame) -{ - first_frame = first_frame ? : get_stack_pointer(task, regs); - - __unwind_start(state, task, regs, first_frame); -} - static inline bool unwind_error(struct unwind_state *state) { return state->error; } -#ifdef CONFIG_FRAME_POINTER - static inline -unsigned long *unwind_get_return_address_ptr(struct unwind_state *state) +void unwind_start(struct unwind_state *state, struct task_struct *task, + struct pt_regs *regs, unsigned long *first_frame) { - if (unwind_done(state)) - return NULL; + first_frame = first_frame ? : get_stack_pointer(task, regs); - return state->regs ? &state->regs->ip : state->bp + 1; + __unwind_start(state, task, regs, first_frame); } +#if defined(CONFIG_ORC_UNWINDER) || defined(CONFIG_FRAME_POINTER) static inline struct pt_regs *unwind_get_entry_regs(struct unwind_state *state) { if (unwind_done(state)) @@ -66,20 +58,46 @@ static inline struct pt_regs *unwind_get_entry_regs(struct unwind_state *state) return state->regs; } - -#else /* !CONFIG_FRAME_POINTER */ - -static inline -unsigned long *unwind_get_return_address_ptr(struct unwind_state *state) +#else +static inline struct pt_regs *unwind_get_entry_regs(struct unwind_state *state) { return NULL; } +#endif -static inline struct pt_regs *unwind_get_entry_regs(struct unwind_state *state) +#ifdef CONFIG_ORC_UNWINDER +void unwind_init(void); +void unwind_module_init(struct module *mod, void *orc_ip, size_t orc_ip_size, + void *orc, size_t orc_size); +#else +static inline void unwind_init(void) {} +static inline +void unwind_module_init(struct module *mod, void *orc_ip, size_t orc_ip_size, + void *orc, size_t orc_size) {} +#endif + +/* + * This disables KASAN checking when reading a value from another task's stack, + * since the other task could be running on another CPU and could have poisoned + * the stack in the meantime. + */ +#define READ_ONCE_TASK_STACK(task, x) \ +({ \ + unsigned long val; \ + if (task == current) \ + val = READ_ONCE(x); \ + else \ + val = READ_ONCE_NOCHECK(x); \ + val; \ +}) + +static inline bool task_on_another_cpu(struct task_struct *task) { - return NULL; +#ifdef CONFIG_SMP + return task != current && task->on_cpu; +#else + return false; +#endif } -#endif /* CONFIG_FRAME_POINTER */ - #endif /* _ASM_X86_UNWIND_H */ diff --git a/arch/x86/kernel/Makefile b/arch/x86/kernel/Makefile index a01892bdd61a..287eac7d207f 100644 --- a/arch/x86/kernel/Makefile +++ b/arch/x86/kernel/Makefile @@ -126,11 +126,9 @@ obj-$(CONFIG_PERF_EVENTS) += perf_regs.o obj-$(CONFIG_TRACING) += tracepoint.o obj-$(CONFIG_SCHED_MC_PRIO) += itmt.o -ifdef CONFIG_FRAME_POINTER -obj-y += unwind_frame.o -else -obj-y += unwind_guess.o -endif +obj-$(CONFIG_ORC_UNWINDER) += unwind_orc.o +obj-$(CONFIG_FRAME_POINTER_UNWINDER) += unwind_frame.o +obj-$(CONFIG_GUESS_UNWINDER) += unwind_guess.o ### # 64 bit specific files diff --git a/arch/x86/kernel/module.c b/arch/x86/kernel/module.c index f67bd3205df7..62e7d70aadd5 100644 --- a/arch/x86/kernel/module.c +++ b/arch/x86/kernel/module.c @@ -35,6 +35,7 @@ #include #include #include +#include #if 0 #define DEBUGP(fmt, ...) \ @@ -213,7 +214,7 @@ int module_finalize(const Elf_Ehdr *hdr, struct module *me) { const Elf_Shdr *s, *text = NULL, *alt = NULL, *locks = NULL, - *para = NULL; + *para = NULL, *orc = NULL, *orc_ip = NULL; char *secstrings = (void *)hdr + sechdrs[hdr->e_shstrndx].sh_offset; for (s = sechdrs; s < sechdrs + hdr->e_shnum; s++) { @@ -225,6 +226,10 @@ int module_finalize(const Elf_Ehdr *hdr, locks = s; if (!strcmp(".parainstructions", secstrings + s->sh_name)) para = s; + if (!strcmp(".orc_unwind", secstrings + s->sh_name)) + orc = s; + if (!strcmp(".orc_unwind_ip", secstrings + s->sh_name)) + orc_ip = s; } if (alt) { @@ -248,6 +253,10 @@ int module_finalize(const Elf_Ehdr *hdr, /* make jump label nops */ jump_label_apply_nops(me); + if (orc && orc_ip) + unwind_module_init(me, (void *)orc_ip->sh_addr, orc_ip->sh_size, + (void *)orc->sh_addr, orc->sh_size); + return 0; } diff --git a/arch/x86/kernel/setup.c b/arch/x86/kernel/setup.c index 3486d0498800..ecab32282f0f 100644 --- a/arch/x86/kernel/setup.c +++ b/arch/x86/kernel/setup.c @@ -115,6 +115,7 @@ #include #include #include +#include /* * max_low_pfn_mapped: highest direct mapped pfn under 4GB @@ -1310,6 +1311,8 @@ void __init setup_arch(char **cmdline_p) if (efi_enabled(EFI_BOOT)) efi_apply_memmap_quirks(); #endif + + unwind_init(); } #ifdef CONFIG_X86_32 diff --git a/arch/x86/kernel/unwind_frame.c b/arch/x86/kernel/unwind_frame.c index b9389d72b2f7..7574ef5f16ec 100644 --- a/arch/x86/kernel/unwind_frame.c +++ b/arch/x86/kernel/unwind_frame.c @@ -10,20 +10,22 @@ #define FRAME_HEADER_SIZE (sizeof(long) * 2) -/* - * This disables KASAN checking when reading a value from another task's stack, - * since the other task could be running on another CPU and could have poisoned - * the stack in the meantime. - */ -#define READ_ONCE_TASK_STACK(task, x) \ -({ \ - unsigned long val; \ - if (task == current) \ - val = READ_ONCE(x); \ - else \ - val = READ_ONCE_NOCHECK(x); \ - val; \ -}) +unsigned long unwind_get_return_address(struct unwind_state *state) +{ + if (unwind_done(state)) + return 0; + + return __kernel_text_address(state->ip) ? state->ip : 0; +} +EXPORT_SYMBOL_GPL(unwind_get_return_address); + +unsigned long *unwind_get_return_address_ptr(struct unwind_state *state) +{ + if (unwind_done(state)) + return NULL; + + return state->regs ? &state->regs->ip : state->bp + 1; +} static void unwind_dump(struct unwind_state *state) { @@ -66,15 +68,6 @@ static void unwind_dump(struct unwind_state *state) } } -unsigned long unwind_get_return_address(struct unwind_state *state) -{ - if (unwind_done(state)) - return 0; - - return __kernel_text_address(state->ip) ? state->ip : 0; -} -EXPORT_SYMBOL_GPL(unwind_get_return_address); - static size_t regs_size(struct pt_regs *regs) { /* x86_32 regs from kernel mode are two words shorter: */ diff --git a/arch/x86/kernel/unwind_guess.c b/arch/x86/kernel/unwind_guess.c index 039f36738e49..4f0e17b90463 100644 --- a/arch/x86/kernel/unwind_guess.c +++ b/arch/x86/kernel/unwind_guess.c @@ -19,6 +19,11 @@ unsigned long unwind_get_return_address(struct unwind_state *state) } EXPORT_SYMBOL_GPL(unwind_get_return_address); +unsigned long *unwind_get_return_address_ptr(struct unwind_state *state) +{ + return NULL; +} + bool unwind_next_frame(struct unwind_state *state) { struct stack_info *info = &state->stack_info; diff --git a/arch/x86/kernel/unwind_orc.c b/arch/x86/kernel/unwind_orc.c new file mode 100644 index 000000000000..570b70d3f604 --- /dev/null +++ b/arch/x86/kernel/unwind_orc.c @@ -0,0 +1,582 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#define orc_warn(fmt, ...) \ + printk_deferred_once(KERN_WARNING pr_fmt("WARNING: " fmt), ##__VA_ARGS__) + +extern int __start_orc_unwind_ip[]; +extern int __stop_orc_unwind_ip[]; +extern struct orc_entry __start_orc_unwind[]; +extern struct orc_entry __stop_orc_unwind[]; + +static DEFINE_MUTEX(sort_mutex); +int *cur_orc_ip_table = __start_orc_unwind_ip; +struct orc_entry *cur_orc_table = __start_orc_unwind; + +unsigned int lookup_num_blocks; +bool orc_init; + +static inline unsigned long orc_ip(const int *ip) +{ + return (unsigned long)ip + *ip; +} + +static struct orc_entry *__orc_find(int *ip_table, struct orc_entry *u_table, + unsigned int num_entries, unsigned long ip) +{ + int *first = ip_table; + int *last = ip_table + num_entries - 1; + int *mid = first, *found = first; + + if (!num_entries) + return NULL; + + /* + * Do a binary range search to find the rightmost duplicate of a given + * starting address. Some entries are section terminators which are + * "weak" entries for ensuring there are no gaps. They should be + * ignored when they conflict with a real entry. + */ + while (first <= last) { + mid = first + ((last - first) / 2); + + if (orc_ip(mid) <= ip) { + found = mid; + first = mid + 1; + } else + last = mid - 1; + } + + return u_table + (found - ip_table); +} + +#ifdef CONFIG_MODULES +static struct orc_entry *orc_module_find(unsigned long ip) +{ + struct module *mod; + + mod = __module_address(ip); + if (!mod || !mod->arch.orc_unwind || !mod->arch.orc_unwind_ip) + return NULL; + return __orc_find(mod->arch.orc_unwind_ip, mod->arch.orc_unwind, + mod->arch.num_orcs, ip); +} +#else +static struct orc_entry *orc_module_find(unsigned long ip) +{ + return NULL; +} +#endif + +static struct orc_entry *orc_find(unsigned long ip) +{ + if (!orc_init) + return NULL; + + /* For non-init vmlinux addresses, use the fast lookup table: */ + if (ip >= LOOKUP_START_IP && ip < LOOKUP_STOP_IP) { + unsigned int idx, start, stop; + + idx = (ip - LOOKUP_START_IP) / LOOKUP_BLOCK_SIZE; + + if (unlikely((idx >= lookup_num_blocks-1))) { + orc_warn("WARNING: bad lookup idx: idx=%u num=%u ip=%lx\n", + idx, lookup_num_blocks, ip); + return NULL; + } + + start = orc_lookup[idx]; + stop = orc_lookup[idx + 1] + 1; + + if (unlikely((__start_orc_unwind + start >= __stop_orc_unwind) || + (__start_orc_unwind + stop > __stop_orc_unwind))) { + orc_warn("WARNING: bad lookup value: idx=%u num=%u start=%u stop=%u ip=%lx\n", + idx, lookup_num_blocks, start, stop, ip); + return NULL; + } + + return __orc_find(__start_orc_unwind_ip + start, + __start_orc_unwind + start, stop - start, ip); + } + + /* vmlinux .init slow lookup: */ + if (ip >= (unsigned long)_sinittext && ip < (unsigned long)_einittext) + return __orc_find(__start_orc_unwind_ip, __start_orc_unwind, + __stop_orc_unwind_ip - __start_orc_unwind_ip, ip); + + /* Module lookup: */ + return orc_module_find(ip); +} + +static void orc_sort_swap(void *_a, void *_b, int size) +{ + struct orc_entry *orc_a, *orc_b; + struct orc_entry orc_tmp; + int *a = _a, *b = _b, tmp; + int delta = _b - _a; + + /* Swap the .orc_unwind_ip entries: */ + tmp = *a; + *a = *b + delta; + *b = tmp - delta; + + /* Swap the corresponding .orc_unwind entries: */ + orc_a = cur_orc_table + (a - cur_orc_ip_table); + orc_b = cur_orc_table + (b - cur_orc_ip_table); + orc_tmp = *orc_a; + *orc_a = *orc_b; + *orc_b = orc_tmp; +} + +static int orc_sort_cmp(const void *_a, const void *_b) +{ + struct orc_entry *orc_a; + const int *a = _a, *b = _b; + unsigned long a_val = orc_ip(a); + unsigned long b_val = orc_ip(b); + + if (a_val > b_val) + return 1; + if (a_val < b_val) + return -1; + + /* + * The "weak" section terminator entries need to always be on the left + * to ensure the lookup code skips them in favor of real entries. + * These terminator entries exist to handle any gaps created by + * whitelisted .o files which didn't get objtool generation. + */ + orc_a = cur_orc_table + (a - cur_orc_ip_table); + return orc_a->sp_reg == ORC_REG_UNDEFINED ? -1 : 1; +} + +#ifdef CONFIG_MODULES +void unwind_module_init(struct module *mod, void *_orc_ip, size_t orc_ip_size, + void *_orc, size_t orc_size) +{ + int *orc_ip = _orc_ip; + struct orc_entry *orc = _orc; + unsigned int num_entries = orc_ip_size / sizeof(int); + + WARN_ON_ONCE(orc_ip_size % sizeof(int) != 0 || + orc_size % sizeof(*orc) != 0 || + num_entries != orc_size / sizeof(*orc)); + + /* + * The 'cur_orc_*' globals allow the orc_sort_swap() callback to + * associate an .orc_unwind_ip table entry with its corresponding + * .orc_unwind entry so they can both be swapped. + */ + mutex_lock(&sort_mutex); + cur_orc_ip_table = orc_ip; + cur_orc_table = orc; + sort(orc_ip, num_entries, sizeof(int), orc_sort_cmp, orc_sort_swap); + mutex_unlock(&sort_mutex); + + mod->arch.orc_unwind_ip = orc_ip; + mod->arch.orc_unwind = orc; + mod->arch.num_orcs = num_entries; +} +#endif + +void __init unwind_init(void) +{ + size_t orc_ip_size = (void *)__stop_orc_unwind_ip - (void *)__start_orc_unwind_ip; + size_t orc_size = (void *)__stop_orc_unwind - (void *)__start_orc_unwind; + size_t num_entries = orc_ip_size / sizeof(int); + struct orc_entry *orc; + int i; + + if (!num_entries || orc_ip_size % sizeof(int) != 0 || + orc_size % sizeof(struct orc_entry) != 0 || + num_entries != orc_size / sizeof(struct orc_entry)) { + orc_warn("WARNING: Bad or missing .orc_unwind table. Disabling unwinder.\n"); + return; + } + + /* Sort the .orc_unwind and .orc_unwind_ip tables: */ + sort(__start_orc_unwind_ip, num_entries, sizeof(int), orc_sort_cmp, + orc_sort_swap); + + /* Initialize the fast lookup table: */ + lookup_num_blocks = orc_lookup_end - orc_lookup; + for (i = 0; i < lookup_num_blocks-1; i++) { + orc = __orc_find(__start_orc_unwind_ip, __start_orc_unwind, + num_entries, + LOOKUP_START_IP + (LOOKUP_BLOCK_SIZE * i)); + if (!orc) { + orc_warn("WARNING: Corrupt .orc_unwind table. Disabling unwinder.\n"); + return; + } + + orc_lookup[i] = orc - __start_orc_unwind; + } + + /* Initialize the ending block: */ + orc = __orc_find(__start_orc_unwind_ip, __start_orc_unwind, num_entries, + LOOKUP_STOP_IP); + if (!orc) { + orc_warn("WARNING: Corrupt .orc_unwind table. Disabling unwinder.\n"); + return; + } + orc_lookup[lookup_num_blocks-1] = orc - __start_orc_unwind; + + orc_init = true; +} + +unsigned long unwind_get_return_address(struct unwind_state *state) +{ + if (unwind_done(state)) + return 0; + + return __kernel_text_address(state->ip) ? state->ip : 0; +} +EXPORT_SYMBOL_GPL(unwind_get_return_address); + +unsigned long *unwind_get_return_address_ptr(struct unwind_state *state) +{ + if (unwind_done(state)) + return NULL; + + if (state->regs) + return &state->regs->ip; + + if (state->sp) + return (unsigned long *)state->sp - 1; + + return NULL; +} + +static bool stack_access_ok(struct unwind_state *state, unsigned long addr, + size_t len) +{ + struct stack_info *info = &state->stack_info; + + /* + * If the address isn't on the current stack, switch to the next one. + * + * We may have to traverse multiple stacks to deal with the possibility + * that info->next_sp could point to an empty stack and the address + * could be on a subsequent stack. + */ + while (!on_stack(info, (void *)addr, len)) + if (get_stack_info(info->next_sp, state->task, info, + &state->stack_mask)) + return false; + + return true; +} + +static bool deref_stack_reg(struct unwind_state *state, unsigned long addr, + unsigned long *val) +{ + if (!stack_access_ok(state, addr, sizeof(long))) + return false; + + *val = READ_ONCE_TASK_STACK(state->task, *(unsigned long *)addr); + return true; +} + +#define REGS_SIZE (sizeof(struct pt_regs)) +#define SP_OFFSET (offsetof(struct pt_regs, sp)) +#define IRET_REGS_SIZE (REGS_SIZE - offsetof(struct pt_regs, ip)) +#define IRET_SP_OFFSET (SP_OFFSET - offsetof(struct pt_regs, ip)) + +static bool deref_stack_regs(struct unwind_state *state, unsigned long addr, + unsigned long *ip, unsigned long *sp, bool full) +{ + size_t regs_size = full ? REGS_SIZE : IRET_REGS_SIZE; + size_t sp_offset = full ? SP_OFFSET : IRET_SP_OFFSET; + struct pt_regs *regs = (struct pt_regs *)(addr + regs_size - REGS_SIZE); + + if (IS_ENABLED(CONFIG_X86_64)) { + if (!stack_access_ok(state, addr, regs_size)) + return false; + + *ip = regs->ip; + *sp = regs->sp; + + return true; + } + + if (!stack_access_ok(state, addr, sp_offset)) + return false; + + *ip = regs->ip; + + if (user_mode(regs)) { + if (!stack_access_ok(state, addr + sp_offset, + REGS_SIZE - SP_OFFSET)) + return false; + + *sp = regs->sp; + } else + *sp = (unsigned long)®s->sp; + + return true; +} + +bool unwind_next_frame(struct unwind_state *state) +{ + unsigned long ip_p, sp, orig_ip, prev_sp = state->sp; + enum stack_type prev_type = state->stack_info.type; + struct orc_entry *orc; + struct pt_regs *ptregs; + bool indirect = false; + + if (unwind_done(state)) + return false; + + /* Don't let modules unload while we're reading their ORC data. */ + preempt_disable(); + + /* Have we reached the end? */ + if (state->regs && user_mode(state->regs)) + goto done; + + /* + * Find the orc_entry associated with the text address. + * + * Decrement call return addresses by one so they work for sibling + * calls and calls to noreturn functions. + */ + orc = orc_find(state->signal ? state->ip : state->ip - 1); + if (!orc || orc->sp_reg == ORC_REG_UNDEFINED) + goto done; + orig_ip = state->ip; + + /* Find the previous frame's stack: */ + switch (orc->sp_reg) { + case ORC_REG_SP: + sp = state->sp + orc->sp_offset; + break; + + case ORC_REG_BP: + sp = state->bp + orc->sp_offset; + break; + + case ORC_REG_SP_INDIRECT: + sp = state->sp + orc->sp_offset; + indirect = true; + break; + + case ORC_REG_BP_INDIRECT: + sp = state->bp + orc->sp_offset; + indirect = true; + break; + + case ORC_REG_R10: + if (!state->regs || !state->full_regs) { + orc_warn("missing regs for base reg R10 at ip %p\n", + (void *)state->ip); + goto done; + } + sp = state->regs->r10; + break; + + case ORC_REG_R13: + if (!state->regs || !state->full_regs) { + orc_warn("missing regs for base reg R13 at ip %p\n", + (void *)state->ip); + goto done; + } + sp = state->regs->r13; + break; + + case ORC_REG_DI: + if (!state->regs || !state->full_regs) { + orc_warn("missing regs for base reg DI at ip %p\n", + (void *)state->ip); + goto done; + } + sp = state->regs->di; + break; + + case ORC_REG_DX: + if (!state->regs || !state->full_regs) { + orc_warn("missing regs for base reg DX at ip %p\n", + (void *)state->ip); + goto done; + } + sp = state->regs->dx; + break; + + default: + orc_warn("unknown SP base reg %d for ip %p\n", + orc->sp_reg, (void *)state->ip); + goto done; + } + + if (indirect) { + if (!deref_stack_reg(state, sp, &sp)) + goto done; + } + + /* Find IP, SP and possibly regs: */ + switch (orc->type) { + case ORC_TYPE_CALL: + ip_p = sp - sizeof(long); + + if (!deref_stack_reg(state, ip_p, &state->ip)) + goto done; + + state->ip = ftrace_graph_ret_addr(state->task, &state->graph_idx, + state->ip, (void *)ip_p); + + state->sp = sp; + state->regs = NULL; + state->signal = false; + break; + + case ORC_TYPE_REGS: + if (!deref_stack_regs(state, sp, &state->ip, &state->sp, true)) { + orc_warn("can't dereference registers at %p for ip %p\n", + (void *)sp, (void *)orig_ip); + goto done; + } + + state->regs = (struct pt_regs *)sp; + state->full_regs = true; + state->signal = true; + break; + + case ORC_TYPE_REGS_IRET: + if (!deref_stack_regs(state, sp, &state->ip, &state->sp, false)) { + orc_warn("can't dereference iret registers at %p for ip %p\n", + (void *)sp, (void *)orig_ip); + goto done; + } + + ptregs = container_of((void *)sp, struct pt_regs, ip); + if ((unsigned long)ptregs >= prev_sp && + on_stack(&state->stack_info, ptregs, REGS_SIZE)) { + state->regs = ptregs; + state->full_regs = false; + } else + state->regs = NULL; + + state->signal = true; + break; + + default: + orc_warn("unknown .orc_unwind entry type %d\n", orc->type); + break; + } + + /* Find BP: */ + switch (orc->bp_reg) { + case ORC_REG_UNDEFINED: + if (state->regs && state->full_regs) + state->bp = state->regs->bp; + break; + + case ORC_REG_PREV_SP: + if (!deref_stack_reg(state, sp + orc->bp_offset, &state->bp)) + goto done; + break; + + case ORC_REG_BP: + if (!deref_stack_reg(state, state->bp + orc->bp_offset, &state->bp)) + goto done; + break; + + default: + orc_warn("unknown BP base reg %d for ip %p\n", + orc->bp_reg, (void *)orig_ip); + goto done; + } + + /* Prevent a recursive loop due to bad ORC data: */ + if (state->stack_info.type == prev_type && + on_stack(&state->stack_info, (void *)state->sp, sizeof(long)) && + state->sp <= prev_sp) { + orc_warn("stack going in the wrong direction? ip=%p\n", + (void *)orig_ip); + goto done; + } + + preempt_enable(); + return true; + +done: + preempt_enable(); + state->stack_info.type = STACK_TYPE_UNKNOWN; + return false; +} +EXPORT_SYMBOL_GPL(unwind_next_frame); + +void __unwind_start(struct unwind_state *state, struct task_struct *task, + struct pt_regs *regs, unsigned long *first_frame) +{ + memset(state, 0, sizeof(*state)); + state->task = task; + + /* + * Refuse to unwind the stack of a task while it's executing on another + * CPU. This check is racy, but that's ok: the unwinder has other + * checks to prevent it from going off the rails. + */ + if (task_on_another_cpu(task)) + goto done; + + if (regs) { + if (user_mode(regs)) + goto done; + + state->ip = regs->ip; + state->sp = kernel_stack_pointer(regs); + state->bp = regs->bp; + state->regs = regs; + state->full_regs = true; + state->signal = true; + + } else if (task == current) { + asm volatile("lea (%%rip), %0\n\t" + "mov %%rsp, %1\n\t" + "mov %%rbp, %2\n\t" + : "=r" (state->ip), "=r" (state->sp), + "=r" (state->bp)); + + } else { + struct inactive_task_frame *frame = (void *)task->thread.sp; + + state->sp = task->thread.sp; + state->bp = READ_ONCE_NOCHECK(frame->bp); + state->ip = READ_ONCE_NOCHECK(frame->ret_addr); + } + + if (get_stack_info((unsigned long *)state->sp, state->task, + &state->stack_info, &state->stack_mask)) + return; + + /* + * The caller can provide the address of the first frame directly + * (first_frame) or indirectly (regs->sp) to indicate which stack frame + * to start unwinding at. Skip ahead until we reach it. + */ + + /* When starting from regs, skip the regs frame: */ + if (regs) { + unwind_next_frame(state); + return; + } + + /* Otherwise, skip ahead to the user-specified starting frame: */ + while (!unwind_done(state) && + (!on_stack(&state->stack_info, first_frame, sizeof(long)) || + state->sp <= (unsigned long)first_frame)) + unwind_next_frame(state); + + return; + +done: + state->stack_info.type = STACK_TYPE_UNKNOWN; + return; +} +EXPORT_SYMBOL_GPL(__unwind_start); diff --git a/arch/x86/kernel/vmlinux.lds.S b/arch/x86/kernel/vmlinux.lds.S index c8a3b61be0aa..f05f00acac89 100644 --- a/arch/x86/kernel/vmlinux.lds.S +++ b/arch/x86/kernel/vmlinux.lds.S @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -148,6 +149,8 @@ SECTIONS BUG_TABLE + ORC_UNWIND_TABLE + . = ALIGN(PAGE_SIZE); __vvar_page = .; diff --git a/include/asm-generic/vmlinux.lds.h b/include/asm-generic/vmlinux.lds.h index da0be9a8d1de..fffc9bdae025 100644 --- a/include/asm-generic/vmlinux.lds.h +++ b/include/asm-generic/vmlinux.lds.h @@ -680,6 +680,31 @@ #define BUG_TABLE #endif +#ifdef CONFIG_ORC_UNWINDER +#define ORC_UNWIND_TABLE \ + . = ALIGN(4); \ + .orc_unwind_ip : AT(ADDR(.orc_unwind_ip) - LOAD_OFFSET) { \ + VMLINUX_SYMBOL(__start_orc_unwind_ip) = .; \ + KEEP(*(.orc_unwind_ip)) \ + VMLINUX_SYMBOL(__stop_orc_unwind_ip) = .; \ + } \ + . = ALIGN(6); \ + .orc_unwind : AT(ADDR(.orc_unwind) - LOAD_OFFSET) { \ + VMLINUX_SYMBOL(__start_orc_unwind) = .; \ + KEEP(*(.orc_unwind)) \ + VMLINUX_SYMBOL(__stop_orc_unwind) = .; \ + } \ + . = ALIGN(4); \ + .orc_lookup : AT(ADDR(.orc_lookup) - LOAD_OFFSET) { \ + VMLINUX_SYMBOL(orc_lookup) = .; \ + . += (((SIZEOF(.text) + LOOKUP_BLOCK_SIZE - 1) / \ + LOOKUP_BLOCK_SIZE) + 1) * 4; \ + VMLINUX_SYMBOL(orc_lookup_end) = .; \ + } +#else +#define ORC_UNWIND_TABLE +#endif + #ifdef CONFIG_PM_TRACE #define TRACEDATA \ . = ALIGN(4); \ @@ -866,7 +891,7 @@ DATA_DATA \ CONSTRUCTORS \ } \ - BUG_TABLE + BUG_TABLE \ #define INIT_TEXT_SECTION(inittext_align) \ . = ALIGN(inittext_align); \ diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 98fe715522e8..0f0d019ffb99 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -374,6 +374,9 @@ config STACK_VALIDATION pointers (if CONFIG_FRAME_POINTER is enabled). This helps ensure that runtime stack traces are more reliable. + This is also a prerequisite for generation of ORC unwind data, which + is needed for CONFIG_ORC_UNWINDER. + For more information, see tools/objtool/Documentation/stack-validation.txt. diff --git a/scripts/Makefile.build b/scripts/Makefile.build index 854608d42e85..80ed14809a05 100644 --- a/scripts/Makefile.build +++ b/scripts/Makefile.build @@ -258,7 +258,8 @@ ifneq ($(SKIP_STACK_VALIDATION),1) __objtool_obj := $(objtree)/tools/objtool/objtool -objtool_args = check +objtool_args = $(if $(CONFIG_ORC_UNWINDER),orc generate,check) + ifndef CONFIG_FRAME_POINTER objtool_args += --no-fp endif @@ -279,6 +280,11 @@ objtool_obj = $(if $(patsubst y%,, \ endif # SKIP_STACK_VALIDATION endif # CONFIG_STACK_VALIDATION +# Rebuild all objects when objtool changes, or is enabled/disabled. +objtool_dep = $(objtool_obj) \ + $(wildcard include/config/orc/unwinder.h \ + include/config/stack/validation.h) + define rule_cc_o_c $(call echo-cmd,checksrc) $(cmd_checksrc) \ $(call cmd_and_fixdep,cc_o_c) \ @@ -301,13 +307,13 @@ cmd_undef_syms = echo endif # Built-in and composite module parts -$(obj)/%.o: $(src)/%.c $(recordmcount_source) $(objtool_obj) FORCE +$(obj)/%.o: $(src)/%.c $(recordmcount_source) $(objtool_dep) FORCE $(call cmd,force_checksrc) $(call if_changed_rule,cc_o_c) # Single-part modules are special since we need to mark them in $(MODVERDIR) -$(single-used-m): $(obj)/%.o: $(src)/%.c $(recordmcount_source) $(objtool_obj) FORCE +$(single-used-m): $(obj)/%.o: $(src)/%.c $(recordmcount_source) $(objtool_dep) FORCE $(call cmd,force_checksrc) $(call if_changed_rule,cc_o_c) @{ echo $(@:.o=.ko); echo $@; \ @@ -402,7 +408,7 @@ cmd_modversions_S = \ endif endif -$(obj)/%.o: $(src)/%.S $(objtool_obj) FORCE +$(obj)/%.o: $(src)/%.S $(objtool_dep) FORCE $(call if_changed_rule,as_o_S) targets += $(real-objs-y) $(real-objs-m) $(lib-y) -- cgit v1.2.3 From a34a766ff96d9e88572e35a45066279e40a85d84 Mon Sep 17 00:00:00 2001 From: Josh Poimboeuf Date: Mon, 24 Jul 2017 18:36:58 -0500 Subject: x86/kconfig: Make it easier to switch to the new ORC unwinder A couple of Kconfig changes which make it much easier to switch to the new CONFIG_ORC_UNWINDER: 1) Remove x86 dependencies on CONFIG_FRAME_POINTER for lockdep, latencytop, and fault injection. x86 has a 'guess' unwinder which just scans the stack for kernel text addresses. It's not 100% accurate but in many cases it's good enough. This allows those users who don't want the text overhead of the frame pointer or ORC unwinders to still use these features. More importantly, this also makes it much more straightforward to disable frame pointers. 2) Make CONFIG_ORC_UNWINDER depend on !CONFIG_FRAME_POINTER. While it would be possible to have both enabled, it doesn't really make sense to do so. So enforce a sane configuration to prevent the user from making a dumb mistake. With these changes, when you disable CONFIG_FRAME_POINTER, "make oldconfig" will ask if you want to enable CONFIG_ORC_UNWINDER. Signed-off-by: Josh Poimboeuf Cc: Andy Lutomirski Cc: Borislav Petkov Cc: Brian Gerst Cc: Denys Vlasenko Cc: H. Peter Anvin Cc: Jiri Slaby Cc: Linus Torvalds Cc: Mike Galbraith Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: live-patching@vger.kernel.org Link: http://lkml.kernel.org/r/9985fb91ce5005fe33ea5cc2a20f14bd33c61d03.1500938583.git.jpoimboe@redhat.com Signed-off-by: Ingo Molnar --- arch/x86/Kconfig.debug | 7 +++---- lib/Kconfig.debug | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) (limited to 'lib') diff --git a/arch/x86/Kconfig.debug b/arch/x86/Kconfig.debug index dc10ec6ead6e..268a318f9439 100644 --- a/arch/x86/Kconfig.debug +++ b/arch/x86/Kconfig.debug @@ -357,7 +357,7 @@ config PUNIT_ATOM_DEBUG config ORC_UNWINDER bool "ORC unwinder" - depends on X86_64 + depends on X86_64 && !FRAME_POINTER select STACK_VALIDATION ---help--- This option enables the ORC (Oops Rewind Capability) unwinder for @@ -365,9 +365,8 @@ config ORC_UNWINDER a simplified version of the DWARF Call Frame Information standard. This unwinder is more accurate across interrupt entry frames than the - frame pointer unwinder. It can also enable a 5-10% performance - improvement across the entire kernel if CONFIG_FRAME_POINTER is - disabled. + frame pointer unwinder. It also enables a 5-10% performance + improvement across the entire kernel compared to frame pointers. Enabling this option will increase the kernel's runtime memory usage by roughly 2-4MB, depending on your kernel config. diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 0f0d019ffb99..32a48e739e26 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1124,7 +1124,7 @@ config LOCKDEP bool depends on DEBUG_KERNEL && TRACE_IRQFLAGS_SUPPORT && STACKTRACE_SUPPORT && LOCKDEP_SUPPORT select STACKTRACE - select FRAME_POINTER if !MIPS && !PPC && !ARM_UNWIND && !S390 && !MICROBLAZE && !ARC && !SCORE + select FRAME_POINTER if !MIPS && !PPC && !ARM_UNWIND && !S390 && !MICROBLAZE && !ARC && !SCORE && !X86 select KALLSYMS select KALLSYMS_ALL @@ -1543,7 +1543,7 @@ config FAULT_INJECTION_STACKTRACE_FILTER depends on FAULT_INJECTION_DEBUG_FS && STACKTRACE_SUPPORT depends on !X86_64 select STACKTRACE - select FRAME_POINTER if !MIPS && !PPC && !S390 && !MICROBLAZE && !ARM_UNWIND && !ARC && !SCORE + select FRAME_POINTER if !MIPS && !PPC && !S390 && !MICROBLAZE && !ARM_UNWIND && !ARC && !SCORE && !X86 help Provide stacktrace filter for fault-injection capabilities @@ -1552,7 +1552,7 @@ config LATENCYTOP depends on DEBUG_KERNEL depends on STACKTRACE_SUPPORT depends on PROC_FS - select FRAME_POINTER if !MIPS && !PPC && !S390 && !MICROBLAZE && !ARM_UNWIND && !ARC + select FRAME_POINTER if !MIPS && !PPC && !S390 && !MICROBLAZE && !ARM_UNWIND && !ARC && !X86 select KALLSYMS select KALLSYMS_ALL select STACKTRACE -- cgit v1.2.3 From 3acdfd280fe7d807237f2cb7a09d6f8f7f1b484f Mon Sep 17 00:00:00 2001 From: Jeff Layton Date: Mon, 24 Jul 2017 06:22:15 -0400 Subject: errseq: rename __errseq_set to errseq_set Nothing calls this wrapper anymore, so just remove it and rename the old function to get rid of the double underscore prefix. Signed-off-by: Jeff Layton --- include/linux/errseq.h | 9 +-------- lib/errseq.c | 17 +++++++---------- mm/filemap.c | 2 +- 3 files changed, 9 insertions(+), 19 deletions(-) (limited to 'lib') diff --git a/include/linux/errseq.h b/include/linux/errseq.h index 9e0d444ac88d..784f0860527b 100644 --- a/include/linux/errseq.h +++ b/include/linux/errseq.h @@ -5,14 +5,7 @@ typedef u32 errseq_t; -errseq_t __errseq_set(errseq_t *eseq, int err); -static inline void errseq_set(errseq_t *eseq, int err) -{ - /* Optimize for the common case of no error */ - if (unlikely(err)) - __errseq_set(eseq, err); -} - +errseq_t errseq_set(errseq_t *eseq, int err); errseq_t errseq_sample(errseq_t *eseq); int errseq_check(errseq_t *eseq, errseq_t since); int errseq_check_and_advance(errseq_t *eseq, errseq_t *since); diff --git a/lib/errseq.c b/lib/errseq.c index 841fa24e6e00..7b900c2a277a 100644 --- a/lib/errseq.c +++ b/lib/errseq.c @@ -41,23 +41,20 @@ #define ERRSEQ_CTR_INC (1 << (ERRSEQ_SHIFT + 1)) /** - * __errseq_set - set a errseq_t for later reporting + * errseq_set - set a errseq_t for later reporting * @eseq: errseq_t field that should be set - * @err: error to set + * @err: error to set (must be between -1 and -MAX_ERRNO) * * This function sets the error in *eseq, and increments the sequence counter * if the last sequence was sampled at some point in the past. * * Any error set will always overwrite an existing error. * - * Most callers will want to use the errseq_set inline wrapper to efficiently - * handle the common case where err is 0. - * - * We do return an errseq_t here, primarily for debugging purposes. The return - * value should not be used as a previously sampled value in later calls as it - * will not have the SEEN flag set. + * We do return the latest value here, primarily for debugging purposes. The + * return value should not be used as a previously sampled value in later calls + * as it will not have the SEEN flag set. */ -errseq_t __errseq_set(errseq_t *eseq, int err) +errseq_t errseq_set(errseq_t *eseq, int err) { errseq_t cur, old; @@ -107,7 +104,7 @@ errseq_t __errseq_set(errseq_t *eseq, int err) } return cur; } -EXPORT_SYMBOL(__errseq_set); +EXPORT_SYMBOL(errseq_set); /** * errseq_sample - grab current errseq_t value diff --git a/mm/filemap.c b/mm/filemap.c index a49702445ce0..e1cca770688f 100644 --- a/mm/filemap.c +++ b/mm/filemap.c @@ -589,7 +589,7 @@ EXPORT_SYMBOL(filemap_write_and_wait_range); void __filemap_set_wb_err(struct address_space *mapping, int err) { - errseq_t eseq = __errseq_set(&mapping->wb_err, err); + errseq_t eseq = errseq_set(&mapping->wb_err, err); trace_filemap_set_wb_err(mapping, eseq); } -- cgit v1.2.3 From 64c83d837329531252a1a0f0dfdd4fd607e1d8e9 Mon Sep 17 00:00:00 2001 From: Jamal Hadi Salim Date: Sun, 30 Jul 2017 13:24:49 -0400 Subject: net netlink: Add new type NLA_BITFIELD32 Generic bitflags attribute content sent to the kernel by user. With this netlink attr type the user can either set or unset a flag in the kernel. The value is a bitmap that defines the bit values being set The selector is a bitmask that defines which value bit is to be considered. A check is made to ensure the rules that a kernel subsystem always conforms to bitflags the kernel already knows about. i.e if the user tries to set a bit flag that is not understood then the _it will be rejected_. In the most basic form, the user specifies the attribute policy as: [ATTR_GOO] = { .type = NLA_BITFIELD32, .validation_data = &myvalidflags }, where myvalidflags is the bit mask of the flags the kernel understands. If the user _does not_ provide myvalidflags then the attribute will also be rejected. Examples: value = 0x0, and selector = 0x1 implies we are selecting bit 1 and we want to set its value to 0. value = 0x2, and selector = 0x2 implies we are selecting bit 2 and we want to set its value to 1. Suggested-by: Jiri Pirko Signed-off-by: Jamal Hadi Salim Signed-off-by: David S. Miller --- include/net/netlink.h | 16 ++++++++++++++++ include/uapi/linux/netlink.h | 17 +++++++++++++++++ lib/nlattr.c | 30 ++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) (limited to 'lib') diff --git a/include/net/netlink.h b/include/net/netlink.h index ef8e6c3a80a6..82dd298b40c7 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -178,6 +178,7 @@ enum { NLA_S16, NLA_S32, NLA_S64, + NLA_BITFIELD32, __NLA_TYPE_MAX, }; @@ -206,6 +207,7 @@ enum { * NLA_MSECS Leaving the length field zero will verify the * given type fits, using it verifies minimum length * just like "All other" + * NLA_BITFIELD32 A 32-bit bitmap/bitselector attribute * All other Minimum length of attribute payload * * Example: @@ -213,11 +215,13 @@ enum { * [ATTR_FOO] = { .type = NLA_U16 }, * [ATTR_BAR] = { .type = NLA_STRING, .len = BARSIZ }, * [ATTR_BAZ] = { .len = sizeof(struct mystruct) }, + * [ATTR_GOO] = { .type = NLA_BITFIELD32, .validation_data = &myvalidflags }, * }; */ struct nla_policy { u16 type; u16 len; + void *validation_data; }; /** @@ -1202,6 +1206,18 @@ static inline struct in6_addr nla_get_in6_addr(const struct nlattr *nla) return tmp; } +/** + * nla_get_bitfield32 - return payload of 32 bitfield attribute + * @nla: nla_bitfield32 attribute + */ +static inline struct nla_bitfield32 nla_get_bitfield32(const struct nlattr *nla) +{ + struct nla_bitfield32 tmp; + + nla_memcpy(&tmp, nla, sizeof(tmp)); + return tmp; +} + /** * nla_memdup - duplicate attribute memory (kmemdup) * @src: netlink attribute to duplicate from diff --git a/include/uapi/linux/netlink.h b/include/uapi/linux/netlink.h index f86127a46cfc..f4fc9c9e123d 100644 --- a/include/uapi/linux/netlink.h +++ b/include/uapi/linux/netlink.h @@ -226,5 +226,22 @@ struct nlattr { #define NLA_ALIGN(len) (((len) + NLA_ALIGNTO - 1) & ~(NLA_ALIGNTO - 1)) #define NLA_HDRLEN ((int) NLA_ALIGN(sizeof(struct nlattr))) +/* Generic 32 bitflags attribute content sent to the kernel. + * + * The value is a bitmap that defines the values being set + * The selector is a bitmask that defines which value is legit + * + * Examples: + * value = 0x0, and selector = 0x1 + * implies we are selecting bit 1 and we want to set its value to 0. + * + * value = 0x2, and selector = 0x2 + * implies we are selecting bit 2 and we want to set its value to 1. + * + */ +struct nla_bitfield32 { + __u32 value; + __u32 selector; +}; #endif /* _UAPI__LINUX_NETLINK_H */ diff --git a/lib/nlattr.c b/lib/nlattr.c index fb52435be42d..ee79b7a3c6b0 100644 --- a/lib/nlattr.c +++ b/lib/nlattr.c @@ -27,6 +27,30 @@ static const u8 nla_attr_minlen[NLA_TYPE_MAX+1] = { [NLA_S64] = sizeof(s64), }; +static int validate_nla_bitfield32(const struct nlattr *nla, + u32 *valid_flags_allowed) +{ + const struct nla_bitfield32 *bf = nla_data(nla); + u32 *valid_flags_mask = valid_flags_allowed; + + if (!valid_flags_allowed) + return -EINVAL; + + /*disallow invalid bit selector */ + if (bf->selector & ~*valid_flags_mask) + return -EINVAL; + + /*disallow invalid bit values */ + if (bf->value & ~*valid_flags_mask) + return -EINVAL; + + /*disallow valid bit values that are not selected*/ + if (bf->value & ~bf->selector) + return -EINVAL; + + return 0; +} + static int validate_nla(const struct nlattr *nla, int maxtype, const struct nla_policy *policy) { @@ -46,6 +70,12 @@ static int validate_nla(const struct nlattr *nla, int maxtype, return -ERANGE; break; + case NLA_BITFIELD32: + if (attrlen != sizeof(struct nla_bitfield32)) + return -ERANGE; + + return validate_nla_bitfield32(nla, pt->validation_data); + case NLA_NUL_STRING: if (pt->len) minlen = min_t(int, attrlen, pt->len + 1); -- cgit v1.2.3 From 2cf0c8b3e6942ecafe6ebb1a6d0328a81641bf39 Mon Sep 17 00:00:00 2001 From: Phil Sutter Date: Thu, 27 Jul 2017 16:56:40 +0200 Subject: netlink: Introduce nla_strdup() This is similar to strdup() for netlink string attributes. Signed-off-by: Phil Sutter Signed-off-by: Pablo Neira Ayuso --- include/net/netlink.h | 1 + lib/nlattr.c | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) (limited to 'lib') diff --git a/include/net/netlink.h b/include/net/netlink.h index ef8e6c3a80a6..c8c2eb5ae55e 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -247,6 +247,7 @@ int nla_parse(struct nlattr **tb, int maxtype, const struct nlattr *head, int nla_policy_len(const struct nla_policy *, int); struct nlattr *nla_find(const struct nlattr *head, int len, int attrtype); size_t nla_strlcpy(char *dst, const struct nlattr *nla, size_t dstsize); +char *nla_strdup(const struct nlattr *nla, gfp_t flags); int nla_memcpy(void *dest, const struct nlattr *src, int count); int nla_memcmp(const struct nlattr *nla, const void *data, size_t size); int nla_strcmp(const struct nlattr *nla, const char *str); diff --git a/lib/nlattr.c b/lib/nlattr.c index fb52435be42d..f13013f7e21a 100644 --- a/lib/nlattr.c +++ b/lib/nlattr.c @@ -271,6 +271,30 @@ size_t nla_strlcpy(char *dst, const struct nlattr *nla, size_t dstsize) } EXPORT_SYMBOL(nla_strlcpy); +/** + * nla_strdup - Copy string attribute payload into a newly allocated buffer + * @nla: attribute to copy the string from + * @flags: the type of memory to allocate (see kmalloc). + * + * Returns a pointer to the allocated buffer or NULL on error. + */ +char *nla_strdup(const struct nlattr *nla, gfp_t flags) +{ + size_t srclen = nla_len(nla); + char *src = nla_data(nla), *dst; + + if (srclen > 0 && src[srclen - 1] == '\0') + srclen--; + + dst = kmalloc(srclen + 1, flags); + if (dst != NULL) { + memcpy(dst, src, srclen); + dst[srclen] = '\0'; + } + return dst; +} +EXPORT_SYMBOL(nla_strdup); + /** * nla_memcpy - Copy a netlink attribute into another memory area * @dest: where to copy to memcpy -- cgit v1.2.3 From 35129dde88afad07f54b332d4f9eda2d254b80f2 Mon Sep 17 00:00:00 2001 From: Ard Biesheuvel Date: Thu, 13 Jul 2017 18:16:00 +0100 Subject: md/raid6: use faster multiplication for ARM NEON delta syndrome The P/Q left side optimization in the delta syndrome simply involves repeatedly multiplying a value by polynomial 'x' in GF(2^8). Given that 'x * x * x * x' equals 'x^4' even in the polynomial world, we can accelerate this substantially by performing up to 4 such operations at once, using the NEON instructions for polynomial multiplication. Results on a Cortex-A57 running in 64-bit mode: Before: ------- raid6: neonx1 xor() 1680 MB/s raid6: neonx2 xor() 2286 MB/s raid6: neonx4 xor() 3162 MB/s raid6: neonx8 xor() 3389 MB/s After: ------ raid6: neonx1 xor() 2281 MB/s raid6: neonx2 xor() 3362 MB/s raid6: neonx4 xor() 3787 MB/s raid6: neonx8 xor() 4239 MB/s While we're at it, simplify MASK() by using a signed shift rather than a vector compare involving a temp register. Signed-off-by: Ard Biesheuvel Signed-off-by: Catalin Marinas --- lib/raid6/neon.uc | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/raid6/neon.uc b/lib/raid6/neon.uc index 4fa51b761dd0..d5242f544551 100644 --- a/lib/raid6/neon.uc +++ b/lib/raid6/neon.uc @@ -46,8 +46,12 @@ static inline unative_t SHLBYTE(unative_t v) */ static inline unative_t MASK(unative_t v) { - const uint8x16_t temp = NBYTES(0); - return (unative_t)vcltq_s8((int8x16_t)v, (int8x16_t)temp); + return (unative_t)vshrq_n_s8((int8x16_t)v, 7); +} + +static inline unative_t PMUL(unative_t v, unative_t u) +{ + return (unative_t)vmulq_p8((poly8x16_t)v, (poly8x16_t)u); } void raid6_neon$#_gen_syndrome_real(int disks, unsigned long bytes, void **ptrs) @@ -110,7 +114,30 @@ void raid6_neon$#_xor_syndrome_real(int disks, int start, int stop, wq$$ = veorq_u8(w1$$, wd$$); } /* P/Q left side optimization */ - for ( z = start-1 ; z >= 0 ; z-- ) { + for ( z = start-1 ; z >= 3 ; z -= 4 ) { + w2$$ = vshrq_n_u8(wq$$, 4); + w1$$ = vshlq_n_u8(wq$$, 4); + + w2$$ = PMUL(w2$$, x1d); + wq$$ = veorq_u8(w1$$, w2$$); + } + + switch (z) { + case 2: + w2$$ = vshrq_n_u8(wq$$, 5); + w1$$ = vshlq_n_u8(wq$$, 3); + + w2$$ = PMUL(w2$$, x1d); + wq$$ = veorq_u8(w1$$, w2$$); + break; + case 1: + w2$$ = vshrq_n_u8(wq$$, 6); + w1$$ = vshlq_n_u8(wq$$, 2); + + w2$$ = PMUL(w2$$, x1d); + wq$$ = veorq_u8(w1$$, w2$$); + break; + case 0: w2$$ = MASK(wq$$); w1$$ = SHLBYTE(wq$$); -- cgit v1.2.3 From 6ec4e2514decd6fb4782a9364fa71d6244d05af4 Mon Sep 17 00:00:00 2001 From: Ard Biesheuvel Date: Thu, 13 Jul 2017 18:16:01 +0100 Subject: md/raid6: implement recovery using ARM NEON intrinsics Provide a NEON accelerated implementation of the recovery algorithm, which supersedes the default byte-by-byte one. Signed-off-by: Ard Biesheuvel Signed-off-by: Catalin Marinas --- include/linux/raid/pq.h | 1 + lib/raid6/Makefile | 4 +- lib/raid6/algos.c | 3 ++ lib/raid6/recov_neon.c | 110 ++++++++++++++++++++++++++++++++++++++++ lib/raid6/recov_neon_inner.c | 117 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 lib/raid6/recov_neon.c create mode 100644 lib/raid6/recov_neon_inner.c (limited to 'lib') diff --git a/include/linux/raid/pq.h b/include/linux/raid/pq.h index 30f945329818..583cdd3d49ca 100644 --- a/include/linux/raid/pq.h +++ b/include/linux/raid/pq.h @@ -121,6 +121,7 @@ extern const struct raid6_recov_calls raid6_recov_ssse3; extern const struct raid6_recov_calls raid6_recov_avx2; extern const struct raid6_recov_calls raid6_recov_avx512; extern const struct raid6_recov_calls raid6_recov_s390xc; +extern const struct raid6_recov_calls raid6_recov_neon; extern const struct raid6_calls raid6_neonx1; extern const struct raid6_calls raid6_neonx2; diff --git a/lib/raid6/Makefile b/lib/raid6/Makefile index 3057011f5599..a93adf6dcfb2 100644 --- a/lib/raid6/Makefile +++ b/lib/raid6/Makefile @@ -5,7 +5,7 @@ raid6_pq-y += algos.o recov.o tables.o int1.o int2.o int4.o \ raid6_pq-$(CONFIG_X86) += recov_ssse3.o recov_avx2.o mmx.o sse1.o sse2.o avx2.o avx512.o recov_avx512.o raid6_pq-$(CONFIG_ALTIVEC) += altivec1.o altivec2.o altivec4.o altivec8.o -raid6_pq-$(CONFIG_KERNEL_MODE_NEON) += neon.o neon1.o neon2.o neon4.o neon8.o +raid6_pq-$(CONFIG_KERNEL_MODE_NEON) += neon.o neon1.o neon2.o neon4.o neon8.o recov_neon.o recov_neon_inner.o raid6_pq-$(CONFIG_TILEGX) += tilegx8.o raid6_pq-$(CONFIG_S390) += s390vx8.o recov_s390xc.o @@ -26,7 +26,9 @@ NEON_FLAGS := -ffreestanding ifeq ($(ARCH),arm) NEON_FLAGS += -mfloat-abi=softfp -mfpu=neon endif +CFLAGS_recov_neon_inner.o += $(NEON_FLAGS) ifeq ($(ARCH),arm64) +CFLAGS_REMOVE_recov_neon_inner.o += -mgeneral-regs-only CFLAGS_REMOVE_neon1.o += -mgeneral-regs-only CFLAGS_REMOVE_neon2.o += -mgeneral-regs-only CFLAGS_REMOVE_neon4.o += -mgeneral-regs-only diff --git a/lib/raid6/algos.c b/lib/raid6/algos.c index 7857049fd7d3..476994723258 100644 --- a/lib/raid6/algos.c +++ b/lib/raid6/algos.c @@ -112,6 +112,9 @@ const struct raid6_recov_calls *const raid6_recov_algos[] = { #endif #ifdef CONFIG_S390 &raid6_recov_s390xc, +#endif +#if defined(CONFIG_KERNEL_MODE_NEON) + &raid6_recov_neon, #endif &raid6_recov_intx1, NULL diff --git a/lib/raid6/recov_neon.c b/lib/raid6/recov_neon.c new file mode 100644 index 000000000000..eeb5c4065b92 --- /dev/null +++ b/lib/raid6/recov_neon.c @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2012 Intel Corporation + * Copyright (C) 2017 Linaro Ltd. + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; version 2 + * of the License. + */ + +#include + +#ifdef __KERNEL__ +#include +#else +#define kernel_neon_begin() +#define kernel_neon_end() +#define cpu_has_neon() (1) +#endif + +static int raid6_has_neon(void) +{ + return cpu_has_neon(); +} + +void __raid6_2data_recov_neon(int bytes, uint8_t *p, uint8_t *q, uint8_t *dp, + uint8_t *dq, const uint8_t *pbmul, + const uint8_t *qmul); + +void __raid6_datap_recov_neon(int bytes, uint8_t *p, uint8_t *q, uint8_t *dq, + const uint8_t *qmul); + +static void raid6_2data_recov_neon(int disks, size_t bytes, int faila, + int failb, void **ptrs) +{ + u8 *p, *q, *dp, *dq; + const u8 *pbmul; /* P multiplier table for B data */ + const u8 *qmul; /* Q multiplier table (for both) */ + + p = (u8 *)ptrs[disks - 2]; + q = (u8 *)ptrs[disks - 1]; + + /* + * Compute syndrome with zero for the missing data pages + * Use the dead data pages as temporary storage for + * delta p and delta q + */ + dp = (u8 *)ptrs[faila]; + ptrs[faila] = (void *)raid6_empty_zero_page; + ptrs[disks - 2] = dp; + dq = (u8 *)ptrs[failb]; + ptrs[failb] = (void *)raid6_empty_zero_page; + ptrs[disks - 1] = dq; + + raid6_call.gen_syndrome(disks, bytes, ptrs); + + /* Restore pointer table */ + ptrs[faila] = dp; + ptrs[failb] = dq; + ptrs[disks - 2] = p; + ptrs[disks - 1] = q; + + /* Now, pick the proper data tables */ + pbmul = raid6_vgfmul[raid6_gfexi[failb-faila]]; + qmul = raid6_vgfmul[raid6_gfinv[raid6_gfexp[faila] ^ + raid6_gfexp[failb]]]; + + kernel_neon_begin(); + __raid6_2data_recov_neon(bytes, p, q, dp, dq, pbmul, qmul); + kernel_neon_end(); +} + +static void raid6_datap_recov_neon(int disks, size_t bytes, int faila, + void **ptrs) +{ + u8 *p, *q, *dq; + const u8 *qmul; /* Q multiplier table */ + + p = (u8 *)ptrs[disks - 2]; + q = (u8 *)ptrs[disks - 1]; + + /* + * Compute syndrome with zero for the missing data page + * Use the dead data page as temporary storage for delta q + */ + dq = (u8 *)ptrs[faila]; + ptrs[faila] = (void *)raid6_empty_zero_page; + ptrs[disks - 1] = dq; + + raid6_call.gen_syndrome(disks, bytes, ptrs); + + /* Restore pointer table */ + ptrs[faila] = dq; + ptrs[disks - 1] = q; + + /* Now, pick the proper data tables */ + qmul = raid6_vgfmul[raid6_gfinv[raid6_gfexp[faila]]]; + + kernel_neon_begin(); + __raid6_datap_recov_neon(bytes, p, q, dq, qmul); + kernel_neon_end(); +} + +const struct raid6_recov_calls raid6_recov_neon = { + .data2 = raid6_2data_recov_neon, + .datap = raid6_datap_recov_neon, + .valid = raid6_has_neon, + .name = "neon", + .priority = 10, +}; diff --git a/lib/raid6/recov_neon_inner.c b/lib/raid6/recov_neon_inner.c new file mode 100644 index 000000000000..8cd20c9f834a --- /dev/null +++ b/lib/raid6/recov_neon_inner.c @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2012 Intel Corporation + * Copyright (C) 2017 Linaro Ltd. + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; version 2 + * of the License. + */ + +#include + +static const uint8x16_t x0f = { + 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, + 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, 0x0f, +}; + +#ifdef CONFIG_ARM +/* + * AArch32 does not provide this intrinsic natively because it does not + * implement the underlying instruction. AArch32 only provides a 64-bit + * wide vtbl.8 instruction, so use that instead. + */ +static uint8x16_t vqtbl1q_u8(uint8x16_t a, uint8x16_t b) +{ + union { + uint8x16_t val; + uint8x8x2_t pair; + } __a = { a }; + + return vcombine_u8(vtbl2_u8(__a.pair, vget_low_u8(b)), + vtbl2_u8(__a.pair, vget_high_u8(b))); +} +#endif + +void __raid6_2data_recov_neon(int bytes, uint8_t *p, uint8_t *q, uint8_t *dp, + uint8_t *dq, const uint8_t *pbmul, + const uint8_t *qmul) +{ + uint8x16_t pm0 = vld1q_u8(pbmul); + uint8x16_t pm1 = vld1q_u8(pbmul + 16); + uint8x16_t qm0 = vld1q_u8(qmul); + uint8x16_t qm1 = vld1q_u8(qmul + 16); + + /* + * while ( bytes-- ) { + * uint8_t px, qx, db; + * + * px = *p ^ *dp; + * qx = qmul[*q ^ *dq]; + * *dq++ = db = pbmul[px] ^ qx; + * *dp++ = db ^ px; + * p++; q++; + * } + */ + + while (bytes) { + uint8x16_t vx, vy, px, qx, db; + + px = veorq_u8(vld1q_u8(p), vld1q_u8(dp)); + vx = veorq_u8(vld1q_u8(q), vld1q_u8(dq)); + + vy = (uint8x16_t)vshrq_n_s16((int16x8_t)vx, 4); + vx = vqtbl1q_u8(qm0, vandq_u8(vx, x0f)); + vy = vqtbl1q_u8(qm1, vandq_u8(vy, x0f)); + qx = veorq_u8(vx, vy); + + vy = (uint8x16_t)vshrq_n_s16((int16x8_t)px, 4); + vx = vqtbl1q_u8(pm0, vandq_u8(px, x0f)); + vy = vqtbl1q_u8(pm1, vandq_u8(vy, x0f)); + vx = veorq_u8(vx, vy); + db = veorq_u8(vx, qx); + + vst1q_u8(dq, db); + vst1q_u8(dp, veorq_u8(db, px)); + + bytes -= 16; + p += 16; + q += 16; + dp += 16; + dq += 16; + } +} + +void __raid6_datap_recov_neon(int bytes, uint8_t *p, uint8_t *q, uint8_t *dq, + const uint8_t *qmul) +{ + uint8x16_t qm0 = vld1q_u8(qmul); + uint8x16_t qm1 = vld1q_u8(qmul + 16); + + /* + * while (bytes--) { + * *p++ ^= *dq = qmul[*q ^ *dq]; + * q++; dq++; + * } + */ + + while (bytes) { + uint8x16_t vx, vy; + + vx = veorq_u8(vld1q_u8(q), vld1q_u8(dq)); + + vy = (uint8x16_t)vshrq_n_s16((int16x8_t)vx, 4); + vx = vqtbl1q_u8(qm0, vandq_u8(vx, x0f)); + vy = vqtbl1q_u8(qm1, vandq_u8(vy, x0f)); + vx = veorq_u8(vx, vy); + vy = veorq_u8(vx, vld1q_u8(p)); + + vst1q_u8(dq, vx); + vst1q_u8(p, vy); + + bytes -= 16; + p += 16; + q += 16; + dq += 16; + } +} -- cgit v1.2.3 From 92b31a9af73b3a3fc801899335d6c47966351830 Mon Sep 17 00:00:00 2001 From: Daniel Borkmann Date: Thu, 10 Aug 2017 01:39:55 +0200 Subject: bpf: add BPF_J{LT,LE,SLT,SLE} instructions Currently, eBPF only understands BPF_JGT (>), BPF_JGE (>=), BPF_JSGT (s>), BPF_JSGE (s>=) instructions, this means that particularly *JLT/*JLE counterparts involving immediates need to be rewritten from e.g. X < [IMM] by swapping arguments into [IMM] > X, meaning the immediate first is required to be loaded into a register Y := [IMM], such that then we can compare with Y > X. Note that the destination operand is always required to be a register. This has the downside of having unnecessarily increased register pressure, meaning complex program would need to spill other registers temporarily to stack in order to obtain an unused register for the [IMM]. Loading to registers will thus also affect state pruning since we need to account for that register use and potentially those registers that had to be spilled/filled again. As a consequence slightly more stack space might have been used due to spilling, and BPF programs are a bit longer due to extra code involving the register load and potentially required spill/fills. Thus, add BPF_JLT (<), BPF_JLE (<=), BPF_JSLT (s<), BPF_JSLE (s<=) counterparts to the eBPF instruction set. Modifying LLVM to remove the NegateCC() workaround in a PoC patch at [1] and allowing it to also emit the new instructions resulted in cilium's BPF programs that are injected into the fast-path to have a reduced program length in the range of 2-3% (e.g. accumulated main and tail call sections from one of the object file reduced from 4864 to 4729 insns), reduced complexity in the range of 10-30% (e.g. accumulated sections reduced in one of the cases from 116432 to 88428 insns), and reduced stack usage in the range of 1-5% (e.g. accumulated sections from one of the object files reduced from 824 to 784b). The modification for LLVM will be incorporated in a backwards compatible way. Plan is for LLVM to have i) a target specific option to offer a possibility to explicitly enable the extension by the user (as we have with -m target specific extensions today for various CPU insns), and ii) have the kernel checked for presence of the extensions and enable them transparently when the user is selecting more aggressive options such as -march=native in a bpf target context. (Other frontends generating BPF byte code, e.g. ply can probe the kernel directly for its code generation.) [1] https://github.com/borkmann/llvm/tree/bpf-insns Signed-off-by: Daniel Borkmann Acked-by: Alexei Starovoitov Signed-off-by: David S. Miller --- Documentation/networking/filter.txt | 4 + include/uapi/linux/bpf.h | 5 + kernel/bpf/core.c | 60 ++++++ lib/test_bpf.c | 364 ++++++++++++++++++++++++++++++++++++ net/core/filter.c | 21 ++- tools/include/uapi/linux/bpf.h | 5 + 6 files changed, 455 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/Documentation/networking/filter.txt b/Documentation/networking/filter.txt index d0fdba7d66e2..6a0df8df6c43 100644 --- a/Documentation/networking/filter.txt +++ b/Documentation/networking/filter.txt @@ -906,6 +906,10 @@ If BPF_CLASS(code) == BPF_JMP, BPF_OP(code) is one of: BPF_JSGE 0x70 /* eBPF only: signed '>=' */ BPF_CALL 0x80 /* eBPF only: function call */ BPF_EXIT 0x90 /* eBPF only: function return */ + BPF_JLT 0xa0 /* eBPF only: unsigned '<' */ + BPF_JLE 0xb0 /* eBPF only: unsigned '<=' */ + BPF_JSLT 0xc0 /* eBPF only: signed '<' */ + BPF_JSLE 0xd0 /* eBPF only: signed '<=' */ So BPF_ADD | BPF_X | BPF_ALU means 32-bit addition in both classic BPF and eBPF. There are only two registers in classic BPF, so it means A += X. diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index 1d06be1569b1..91da8371a2d0 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -30,9 +30,14 @@ #define BPF_FROM_LE BPF_TO_LE #define BPF_FROM_BE BPF_TO_BE +/* jmp encodings */ #define BPF_JNE 0x50 /* jump != */ +#define BPF_JLT 0xa0 /* LT is unsigned, '<' */ +#define BPF_JLE 0xb0 /* LE is unsigned, '<=' */ #define BPF_JSGT 0x60 /* SGT is signed '>', GT in x86 */ #define BPF_JSGE 0x70 /* SGE is signed '>=', GE in x86 */ +#define BPF_JSLT 0xc0 /* SLT is signed, '<' */ +#define BPF_JSLE 0xd0 /* SLE is signed, '<=' */ #define BPF_CALL 0x80 /* function call */ #define BPF_EXIT 0x90 /* function return */ diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c index ad5f55922a13..c69e7f5bfde7 100644 --- a/kernel/bpf/core.c +++ b/kernel/bpf/core.c @@ -595,9 +595,13 @@ static int bpf_jit_blind_insn(const struct bpf_insn *from, case BPF_JMP | BPF_JEQ | BPF_K: case BPF_JMP | BPF_JNE | BPF_K: case BPF_JMP | BPF_JGT | BPF_K: + case BPF_JMP | BPF_JLT | BPF_K: case BPF_JMP | BPF_JGE | BPF_K: + case BPF_JMP | BPF_JLE | BPF_K: case BPF_JMP | BPF_JSGT | BPF_K: + case BPF_JMP | BPF_JSLT | BPF_K: case BPF_JMP | BPF_JSGE | BPF_K: + case BPF_JMP | BPF_JSLE | BPF_K: case BPF_JMP | BPF_JSET | BPF_K: /* Accommodate for extra offset in case of a backjump. */ off = from->off; @@ -833,12 +837,20 @@ static unsigned int ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn, [BPF_JMP | BPF_JNE | BPF_K] = &&JMP_JNE_K, [BPF_JMP | BPF_JGT | BPF_X] = &&JMP_JGT_X, [BPF_JMP | BPF_JGT | BPF_K] = &&JMP_JGT_K, + [BPF_JMP | BPF_JLT | BPF_X] = &&JMP_JLT_X, + [BPF_JMP | BPF_JLT | BPF_K] = &&JMP_JLT_K, [BPF_JMP | BPF_JGE | BPF_X] = &&JMP_JGE_X, [BPF_JMP | BPF_JGE | BPF_K] = &&JMP_JGE_K, + [BPF_JMP | BPF_JLE | BPF_X] = &&JMP_JLE_X, + [BPF_JMP | BPF_JLE | BPF_K] = &&JMP_JLE_K, [BPF_JMP | BPF_JSGT | BPF_X] = &&JMP_JSGT_X, [BPF_JMP | BPF_JSGT | BPF_K] = &&JMP_JSGT_K, + [BPF_JMP | BPF_JSLT | BPF_X] = &&JMP_JSLT_X, + [BPF_JMP | BPF_JSLT | BPF_K] = &&JMP_JSLT_K, [BPF_JMP | BPF_JSGE | BPF_X] = &&JMP_JSGE_X, [BPF_JMP | BPF_JSGE | BPF_K] = &&JMP_JSGE_K, + [BPF_JMP | BPF_JSLE | BPF_X] = &&JMP_JSLE_X, + [BPF_JMP | BPF_JSLE | BPF_K] = &&JMP_JSLE_K, [BPF_JMP | BPF_JSET | BPF_X] = &&JMP_JSET_X, [BPF_JMP | BPF_JSET | BPF_K] = &&JMP_JSET_K, /* Program return */ @@ -1073,6 +1085,18 @@ out: CONT_JMP; } CONT; + JMP_JLT_X: + if (DST < SRC) { + insn += insn->off; + CONT_JMP; + } + CONT; + JMP_JLT_K: + if (DST < IMM) { + insn += insn->off; + CONT_JMP; + } + CONT; JMP_JGE_X: if (DST >= SRC) { insn += insn->off; @@ -1085,6 +1109,18 @@ out: CONT_JMP; } CONT; + JMP_JLE_X: + if (DST <= SRC) { + insn += insn->off; + CONT_JMP; + } + CONT; + JMP_JLE_K: + if (DST <= IMM) { + insn += insn->off; + CONT_JMP; + } + CONT; JMP_JSGT_X: if (((s64) DST) > ((s64) SRC)) { insn += insn->off; @@ -1097,6 +1133,18 @@ out: CONT_JMP; } CONT; + JMP_JSLT_X: + if (((s64) DST) < ((s64) SRC)) { + insn += insn->off; + CONT_JMP; + } + CONT; + JMP_JSLT_K: + if (((s64) DST) < ((s64) IMM)) { + insn += insn->off; + CONT_JMP; + } + CONT; JMP_JSGE_X: if (((s64) DST) >= ((s64) SRC)) { insn += insn->off; @@ -1109,6 +1157,18 @@ out: CONT_JMP; } CONT; + JMP_JSLE_X: + if (((s64) DST) <= ((s64) SRC)) { + insn += insn->off; + CONT_JMP; + } + CONT; + JMP_JSLE_K: + if (((s64) DST) <= ((s64) IMM)) { + insn += insn->off; + CONT_JMP; + } + CONT; JMP_JSET_X: if (DST & SRC) { insn += insn->off; diff --git a/lib/test_bpf.c b/lib/test_bpf.c index d9d5a410955c..aa8812ae6776 100644 --- a/lib/test_bpf.c +++ b/lib/test_bpf.c @@ -951,6 +951,32 @@ static struct bpf_test tests[] = { { 4, 4, 4, 3, 3 }, { { 2, 0 }, { 3, 1 }, { 4, MAX_K } }, }, + { + "JGE (jt 0), test 1", + .u.insns = { + BPF_STMT(BPF_LDX | BPF_LEN, 0), + BPF_STMT(BPF_LD | BPF_B | BPF_ABS, 2), + BPF_JUMP(BPF_JMP | BPF_JGE | BPF_X, 0, 0, 1), + BPF_STMT(BPF_RET | BPF_K, 1), + BPF_STMT(BPF_RET | BPF_K, MAX_K) + }, + CLASSIC, + { 4, 4, 4, 3, 3 }, + { { 2, 0 }, { 3, 1 }, { 4, 1 } }, + }, + { + "JGE (jt 0), test 2", + .u.insns = { + BPF_STMT(BPF_LDX | BPF_LEN, 0), + BPF_STMT(BPF_LD | BPF_B | BPF_ABS, 2), + BPF_JUMP(BPF_JMP | BPF_JGE | BPF_X, 0, 0, 1), + BPF_STMT(BPF_RET | BPF_K, 1), + BPF_STMT(BPF_RET | BPF_K, MAX_K) + }, + CLASSIC, + { 4, 4, 5, 3, 3 }, + { { 4, 1 }, { 5, 1 }, { 6, MAX_K } }, + }, { "JGE", .u.insns = { @@ -4492,6 +4518,35 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JSLT | BPF_K */ + { + "JMP_JSLT_K: Signed jump: if (-2 < -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 0xfffffffffffffffeLL), + BPF_JMP_IMM(BPF_JSLT, R1, -1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLT_K: Signed jump: if (-1 < -1) return 0", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_LD_IMM64(R1, 0xffffffffffffffffLL), + BPF_JMP_IMM(BPF_JSLT, R1, -1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JSGT | BPF_K */ { "JMP_JSGT_K: Signed jump: if (-1 > -2) return 1", @@ -4521,6 +4576,73 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JSLE | BPF_K */ + { + "JMP_JSLE_K: Signed jump: if (-2 <= -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 0xfffffffffffffffeLL), + BPF_JMP_IMM(BPF_JSLE, R1, -1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLE_K: Signed jump: if (-1 <= -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 0xffffffffffffffffLL), + BPF_JMP_IMM(BPF_JSLE, R1, -1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLE_K: Signed jump: value walk 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 6), + BPF_ALU64_IMM(BPF_SUB, R1, 1), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 4), + BPF_ALU64_IMM(BPF_SUB, R1, 1), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 2), + BPF_ALU64_IMM(BPF_SUB, R1, 1), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 1), + BPF_EXIT_INSN(), /* bad exit */ + BPF_ALU32_IMM(BPF_MOV, R0, 1), /* good exit */ + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLE_K: Signed jump: value walk 2", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 4), + BPF_ALU64_IMM(BPF_SUB, R1, 2), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 2), + BPF_ALU64_IMM(BPF_SUB, R1, 2), + BPF_JMP_IMM(BPF_JSLE, R1, 0, 1), + BPF_EXIT_INSN(), /* bad exit */ + BPF_ALU32_IMM(BPF_MOV, R0, 1), /* good exit */ + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JSGE | BPF_K */ { "JMP_JSGE_K: Signed jump: if (-1 >= -2) return 1", @@ -4617,6 +4739,35 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JLT | BPF_K */ + { + "JMP_JLT_K: if (2 < 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 2), + BPF_JMP_IMM(BPF_JLT, R1, 3, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JGT_K: Unsigned jump: if (1 < -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 1), + BPF_JMP_IMM(BPF_JLT, R1, -1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JGE | BPF_K */ { "JMP_JGE_K: if (3 >= 2) return 1", @@ -4632,6 +4783,21 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JLE | BPF_K */ + { + "JMP_JLE_K: if (2 <= 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 2), + BPF_JMP_IMM(BPF_JLE, R1, 3, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JGT | BPF_K jump backwards */ { "JMP_JGT_K: if (3 > 2) return 1 (jump backwards)", @@ -4662,6 +4828,36 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JLT | BPF_K jump backwards */ + { + "JMP_JGT_K: if (2 < 3) return 1 (jump backwards)", + .u.insns_int = { + BPF_JMP_IMM(BPF_JA, 0, 0, 2), /* goto start */ + BPF_ALU32_IMM(BPF_MOV, R0, 1), /* out: */ + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 0), /* start: */ + BPF_LD_IMM64(R1, 2), /* note: this takes 2 insns */ + BPF_JMP_IMM(BPF_JLT, R1, 3, -6), /* goto out */ + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JLE_K: if (3 <= 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_JMP_IMM(BPF_JLE, R1, 3, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JNE | BPF_K */ { "JMP_JNE_K: if (3 != 2) return 1", @@ -4752,6 +4948,37 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JSLT | BPF_X */ + { + "JMP_JSLT_X: Signed jump: if (-2 < -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, -1), + BPF_LD_IMM64(R2, -2), + BPF_JMP_REG(BPF_JSLT, R2, R1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLT_X: Signed jump: if (-1 < -1) return 0", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_LD_IMM64(R1, -1), + BPF_LD_IMM64(R2, -1), + BPF_JMP_REG(BPF_JSLT, R1, R2, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JSGE | BPF_X */ { "JMP_JSGE_X: Signed jump: if (-1 >= -2) return 1", @@ -4783,6 +5010,37 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JSLE | BPF_X */ + { + "JMP_JSLE_X: Signed jump: if (-2 <= -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, -1), + BPF_LD_IMM64(R2, -2), + BPF_JMP_REG(BPF_JSLE, R2, R1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JSLE_X: Signed jump: if (-1 <= -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, -1), + BPF_LD_IMM64(R2, -1), + BPF_JMP_REG(BPF_JSLE, R1, R2, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JGT | BPF_X */ { "JMP_JGT_X: if (3 > 2) return 1", @@ -4814,6 +5072,37 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JLT | BPF_X */ + { + "JMP_JLT_X: if (2 < 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 2), + BPF_JMP_REG(BPF_JLT, R2, R1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JLT_X: Unsigned jump: if (1 < -1) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, -1), + BPF_LD_IMM64(R2, 1), + BPF_JMP_REG(BPF_JLT, R2, R1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JGE | BPF_X */ { "JMP_JGE_X: if (3 >= 2) return 1", @@ -4845,6 +5134,37 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + /* BPF_JMP | BPF_JLE | BPF_X */ + { + "JMP_JLE_X: if (2 <= 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 2), + BPF_JMP_REG(BPF_JLE, R2, R1, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, + { + "JMP_JLE_X: if (3 <= 3) return 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 3), + BPF_JMP_REG(BPF_JLE, R1, R2, 1), + BPF_EXIT_INSN(), + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, { /* Mainly testing JIT + imm64 here. */ "JMP_JGE_X: ldimm64 test 1", @@ -4890,6 +5210,50 @@ static struct bpf_test tests[] = { { }, { { 0, 1 } }, }, + { + "JMP_JLE_X: ldimm64 test 1", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 2), + BPF_JMP_REG(BPF_JLE, R2, R1, 2), + BPF_LD_IMM64(R0, 0xffffffffffffffffULL), + BPF_LD_IMM64(R0, 0xeeeeeeeeeeeeeeeeULL), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 0xeeeeeeeeU } }, + }, + { + "JMP_JLE_X: ldimm64 test 2", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 0), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 2), + BPF_JMP_REG(BPF_JLE, R2, R1, 0), + BPF_LD_IMM64(R0, 0xffffffffffffffffULL), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 0xffffffffU } }, + }, + { + "JMP_JLE_X: ldimm64 test 3", + .u.insns_int = { + BPF_ALU32_IMM(BPF_MOV, R0, 1), + BPF_LD_IMM64(R1, 3), + BPF_LD_IMM64(R2, 2), + BPF_JMP_REG(BPF_JLE, R2, R1, 4), + BPF_LD_IMM64(R0, 0xffffffffffffffffULL), + BPF_LD_IMM64(R0, 0xeeeeeeeeeeeeeeeeULL), + BPF_EXIT_INSN(), + }, + INTERNAL, + { }, + { { 0, 1 } }, + }, /* BPF_JMP | BPF_JNE | BPF_X */ { "JMP_JNE_X: if (3 != 2) return 1", diff --git a/net/core/filter.c b/net/core/filter.c index 78d00933dbe7..5afe3ac191ec 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -514,14 +514,27 @@ do_pass: break; } - /* Convert JEQ into JNE when 'jump_true' is next insn. */ - if (fp->jt == 0 && BPF_OP(fp->code) == BPF_JEQ) { - insn->code = BPF_JMP | BPF_JNE | bpf_src; + /* Convert some jumps when 'jump_true' is next insn. */ + if (fp->jt == 0) { + switch (BPF_OP(fp->code)) { + case BPF_JEQ: + insn->code = BPF_JMP | BPF_JNE | bpf_src; + break; + case BPF_JGT: + insn->code = BPF_JMP | BPF_JLE | bpf_src; + break; + case BPF_JGE: + insn->code = BPF_JMP | BPF_JLT | bpf_src; + break; + default: + goto jmp_rest; + } + target = i + fp->jf + 1; BPF_EMIT_JMP; break; } - +jmp_rest: /* Other jumps are mapped into two insns: Jxx and JA. */ target = i + fp->jt + 1; insn->code = BPF_JMP | BPF_OP(fp->code) | bpf_src; diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index 8d9bfcca3fe4..bf3b2e230455 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -30,9 +30,14 @@ #define BPF_FROM_LE BPF_TO_LE #define BPF_FROM_BE BPF_TO_BE +/* jmp encodings */ #define BPF_JNE 0x50 /* jump != */ +#define BPF_JLT 0xa0 /* LT is unsigned, '<' */ +#define BPF_JLE 0xb0 /* LE is unsigned, '<=' */ #define BPF_JSGT 0x60 /* SGT is signed '>', GT in x86 */ #define BPF_JSGE 0x70 /* SGE is signed '>=', GE in x86 */ +#define BPF_JSLT 0xc0 /* SLT is signed, '<' */ +#define BPF_JSLE 0xd0 /* SLE is signed, '<=' */ #define BPF_CALL 0x80 /* function call */ #define BPF_EXIT 0x90 /* function return */ -- cgit v1.2.3 From b09be676e0ff25bd6d2e7637e26d349f9109ad75 Mon Sep 17 00:00:00 2001 From: Byungchul Park Date: Mon, 7 Aug 2017 16:12:52 +0900 Subject: locking/lockdep: Implement the 'crossrelease' feature Lockdep is a runtime locking correctness validator that detects and reports a deadlock or its possibility by checking dependencies between locks. It's useful since it does not report just an actual deadlock but also the possibility of a deadlock that has not actually happened yet. That enables problems to be fixed before they affect real systems. However, this facility is only applicable to typical locks, such as spinlocks and mutexes, which are normally released within the context in which they were acquired. However, synchronization primitives like page locks or completions, which are allowed to be released in any context, also create dependencies and can cause a deadlock. So lockdep should track these locks to do a better job. The 'crossrelease' implementation makes these primitives also be tracked. Signed-off-by: Byungchul Park Signed-off-by: Peter Zijlstra (Intel) Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: akpm@linux-foundation.org Cc: boqun.feng@gmail.com Cc: kernel-team@lge.com Cc: kirill@shutemov.name Cc: npiggin@gmail.com Cc: walken@google.com Cc: willy@infradead.org Link: http://lkml.kernel.org/r/1502089981-21272-6-git-send-email-byungchul.park@lge.com Signed-off-by: Ingo Molnar --- include/linux/irqflags.h | 24 ++- include/linux/lockdep.h | 110 +++++++++- include/linux/sched.h | 8 + kernel/exit.c | 1 + kernel/fork.c | 4 + kernel/locking/lockdep.c | 508 ++++++++++++++++++++++++++++++++++++++++++++--- kernel/workqueue.c | 2 + lib/Kconfig.debug | 12 ++ 8 files changed, 635 insertions(+), 34 deletions(-) (limited to 'lib') diff --git a/include/linux/irqflags.h b/include/linux/irqflags.h index 5dd1272d1ab2..5fdd93bb9300 100644 --- a/include/linux/irqflags.h +++ b/include/linux/irqflags.h @@ -23,10 +23,26 @@ # define trace_softirq_context(p) ((p)->softirq_context) # define trace_hardirqs_enabled(p) ((p)->hardirqs_enabled) # define trace_softirqs_enabled(p) ((p)->softirqs_enabled) -# define trace_hardirq_enter() do { current->hardirq_context++; } while (0) -# define trace_hardirq_exit() do { current->hardirq_context--; } while (0) -# define lockdep_softirq_enter() do { current->softirq_context++; } while (0) -# define lockdep_softirq_exit() do { current->softirq_context--; } while (0) +# define trace_hardirq_enter() \ +do { \ + current->hardirq_context++; \ + crossrelease_hist_start(XHLOCK_HARD); \ +} while (0) +# define trace_hardirq_exit() \ +do { \ + current->hardirq_context--; \ + crossrelease_hist_end(XHLOCK_HARD); \ +} while (0) +# define lockdep_softirq_enter() \ +do { \ + current->softirq_context++; \ + crossrelease_hist_start(XHLOCK_SOFT); \ +} while (0) +# define lockdep_softirq_exit() \ +do { \ + current->softirq_context--; \ + crossrelease_hist_end(XHLOCK_SOFT); \ +} while (0) # define INIT_TRACE_IRQFLAGS .softirqs_enabled = 1, #else # define trace_hardirqs_on() do { } while (0) diff --git a/include/linux/lockdep.h b/include/linux/lockdep.h index 0a4c02c2d7a2..e1e0fcd99613 100644 --- a/include/linux/lockdep.h +++ b/include/linux/lockdep.h @@ -155,6 +155,12 @@ struct lockdep_map { int cpu; unsigned long ip; #endif +#ifdef CONFIG_LOCKDEP_CROSSRELEASE + /* + * Whether it's a crosslock. + */ + int cross; +#endif }; static inline void lockdep_copy_map(struct lockdep_map *to, @@ -258,8 +264,62 @@ struct held_lock { unsigned int hardirqs_off:1; unsigned int references:12; /* 32 bits */ unsigned int pin_count; +#ifdef CONFIG_LOCKDEP_CROSSRELEASE + /* + * Generation id. + * + * A value of cross_gen_id will be stored when holding this, + * which is globally increased whenever each crosslock is held. + */ + unsigned int gen_id; +#endif +}; + +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +#define MAX_XHLOCK_TRACE_ENTRIES 5 + +/* + * This is for keeping locks waiting for commit so that true dependencies + * can be added at commit step. + */ +struct hist_lock { + /* + * Seperate stack_trace data. This will be used at commit step. + */ + struct stack_trace trace; + unsigned long trace_entries[MAX_XHLOCK_TRACE_ENTRIES]; + + /* + * Seperate hlock instance. This will be used at commit step. + * + * TODO: Use a smaller data structure containing only necessary + * data. However, we should make lockdep code able to handle the + * smaller one first. + */ + struct held_lock hlock; +}; + +/* + * To initialize a lock as crosslock, lockdep_init_map_crosslock() should + * be called instead of lockdep_init_map(). + */ +struct cross_lock { + /* + * Seperate hlock instance. This will be used at commit step. + * + * TODO: Use a smaller data structure containing only necessary + * data. However, we should make lockdep code able to handle the + * smaller one first. + */ + struct held_lock hlock; }; +struct lockdep_map_cross { + struct lockdep_map map; + struct cross_lock xlock; +}; +#endif + /* * Initialization, self-test and debugging-output methods: */ @@ -281,13 +341,6 @@ extern void lockdep_on(void); extern void lockdep_init_map(struct lockdep_map *lock, const char *name, struct lock_class_key *key, int subclass); -/* - * To initialize a lockdep_map statically use this macro. - * Note that _name must not be NULL. - */ -#define STATIC_LOCKDEP_MAP_INIT(_name, _key) \ - { .name = (_name), .key = (void *)(_key), } - /* * Reinitialize a lock key - for cases where there is special locking or * special initialization of locks so that the validator gets the scope @@ -460,6 +513,49 @@ struct pin_cookie { }; #endif /* !LOCKDEP */ +enum xhlock_context_t { + XHLOCK_HARD, + XHLOCK_SOFT, + XHLOCK_PROC, + XHLOCK_CTX_NR, +}; + +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +extern void lockdep_init_map_crosslock(struct lockdep_map *lock, + const char *name, + struct lock_class_key *key, + int subclass); +extern void lock_commit_crosslock(struct lockdep_map *lock); + +#define STATIC_CROSS_LOCKDEP_MAP_INIT(_name, _key) \ + { .map.name = (_name), .map.key = (void *)(_key), \ + .map.cross = 1, } + +/* + * To initialize a lockdep_map statically use this macro. + * Note that _name must not be NULL. + */ +#define STATIC_LOCKDEP_MAP_INIT(_name, _key) \ + { .name = (_name), .key = (void *)(_key), .cross = 0, } + +extern void crossrelease_hist_start(enum xhlock_context_t c); +extern void crossrelease_hist_end(enum xhlock_context_t c); +extern void lockdep_init_task(struct task_struct *task); +extern void lockdep_free_task(struct task_struct *task); +#else +/* + * To initialize a lockdep_map statically use this macro. + * Note that _name must not be NULL. + */ +#define STATIC_LOCKDEP_MAP_INIT(_name, _key) \ + { .name = (_name), .key = (void *)(_key), } + +static inline void crossrelease_hist_start(enum xhlock_context_t c) {} +static inline void crossrelease_hist_end(enum xhlock_context_t c) {} +static inline void lockdep_init_task(struct task_struct *task) {} +static inline void lockdep_free_task(struct task_struct *task) {} +#endif + #ifdef CONFIG_LOCK_STAT extern void lock_contended(struct lockdep_map *lock, unsigned long ip); diff --git a/include/linux/sched.h b/include/linux/sched.h index 57db70ec70d0..5235fba537fc 100644 --- a/include/linux/sched.h +++ b/include/linux/sched.h @@ -848,6 +848,14 @@ struct task_struct { struct held_lock held_locks[MAX_LOCK_DEPTH]; #endif +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +#define MAX_XHLOCKS_NR 64UL + struct hist_lock *xhlocks; /* Crossrelease history locks */ + unsigned int xhlock_idx; + /* For restoring at history boundaries */ + unsigned int xhlock_idx_hist[XHLOCK_CTX_NR]; +#endif + #ifdef CONFIG_UBSAN unsigned int in_ubsan; #endif diff --git a/kernel/exit.c b/kernel/exit.c index c5548faa9f37..fa72d57db747 100644 --- a/kernel/exit.c +++ b/kernel/exit.c @@ -920,6 +920,7 @@ void __noreturn do_exit(long code) exit_rcu(); TASKS_RCU(__srcu_read_unlock(&tasks_rcu_exit_srcu, tasks_rcu_i)); + lockdep_free_task(tsk); do_task_dead(); } EXPORT_SYMBOL_GPL(do_exit); diff --git a/kernel/fork.c b/kernel/fork.c index 17921b0390b4..cbf2221ee81a 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -484,6 +484,8 @@ void __init fork_init(void) cpuhp_setup_state(CPUHP_BP_PREPARE_DYN, "fork:vm_stack_cache", NULL, free_vm_stack_cache); #endif + + lockdep_init_task(&init_task); } int __weak arch_dup_task_struct(struct task_struct *dst, @@ -1691,6 +1693,7 @@ static __latent_entropy struct task_struct *copy_process( p->lockdep_depth = 0; /* no locks held yet */ p->curr_chain_key = 0; p->lockdep_recursion = 0; + lockdep_init_task(p); #endif #ifdef CONFIG_DEBUG_MUTEXES @@ -1949,6 +1952,7 @@ bad_fork_cleanup_audit: bad_fork_cleanup_perf: perf_event_free_task(p); bad_fork_cleanup_policy: + lockdep_free_task(p); #ifdef CONFIG_NUMA mpol_put(p->mempolicy); bad_fork_cleanup_threadgroup_lock: diff --git a/kernel/locking/lockdep.c b/kernel/locking/lockdep.c index 841828ba35b9..56f69cc53ddc 100644 --- a/kernel/locking/lockdep.c +++ b/kernel/locking/lockdep.c @@ -58,6 +58,10 @@ #define CREATE_TRACE_POINTS #include +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +#include +#endif + #ifdef CONFIG_PROVE_LOCKING int prove_locking = 1; module_param(prove_locking, int, 0644); @@ -724,6 +728,18 @@ look_up_lock_class(struct lockdep_map *lock, unsigned int subclass) return is_static || static_obj(lock->key) ? NULL : ERR_PTR(-EINVAL); } +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +static void cross_init(struct lockdep_map *lock, int cross); +static int cross_lock(struct lockdep_map *lock); +static int lock_acquire_crosslock(struct held_lock *hlock); +static int lock_release_crosslock(struct lockdep_map *lock); +#else +static inline void cross_init(struct lockdep_map *lock, int cross) {} +static inline int cross_lock(struct lockdep_map *lock) { return 0; } +static inline int lock_acquire_crosslock(struct held_lock *hlock) { return 2; } +static inline int lock_release_crosslock(struct lockdep_map *lock) { return 2; } +#endif + /* * Register a lock's class in the hash-table, if the class is not present * yet. Otherwise we look it up. We cache the result in the lock object @@ -1795,6 +1811,9 @@ check_deadlock(struct task_struct *curr, struct held_lock *next, if (nest) return 2; + if (cross_lock(prev->instance)) + continue; + return print_deadlock_bug(curr, prev, next); } return 1; @@ -1962,30 +1981,36 @@ check_prevs_add(struct task_struct *curr, struct held_lock *next) int distance = curr->lockdep_depth - depth + 1; hlock = curr->held_locks + depth - 1; /* - * Only non-recursive-read entries get new dependencies - * added: + * Only non-crosslock entries get new dependencies added. + * Crosslock entries will be added by commit later: */ - if (hlock->read != 2 && hlock->check) { - int ret = check_prev_add(curr, hlock, next, - distance, &trace, save); - if (!ret) - return 0; - + if (!cross_lock(hlock->instance)) { /* - * Stop saving stack_trace if save_trace() was - * called at least once: + * Only non-recursive-read entries get new dependencies + * added: */ - if (save && ret == 2) - save = NULL; + if (hlock->read != 2 && hlock->check) { + int ret = check_prev_add(curr, hlock, next, + distance, &trace, save); + if (!ret) + return 0; - /* - * Stop after the first non-trylock entry, - * as non-trylock entries have added their - * own direct dependencies already, so this - * lock is connected to them indirectly: - */ - if (!hlock->trylock) - break; + /* + * Stop saving stack_trace if save_trace() was + * called at least once: + */ + if (save && ret == 2) + save = NULL; + + /* + * Stop after the first non-trylock entry, + * as non-trylock entries have added their + * own direct dependencies already, so this + * lock is connected to them indirectly: + */ + if (!hlock->trylock) + break; + } } depth--; /* @@ -3176,7 +3201,7 @@ static int mark_lock(struct task_struct *curr, struct held_lock *this, /* * Initialize a lock instance's lock-class mapping info: */ -void lockdep_init_map(struct lockdep_map *lock, const char *name, +static void __lockdep_init_map(struct lockdep_map *lock, const char *name, struct lock_class_key *key, int subclass) { int i; @@ -3234,8 +3259,25 @@ void lockdep_init_map(struct lockdep_map *lock, const char *name, raw_local_irq_restore(flags); } } + +void lockdep_init_map(struct lockdep_map *lock, const char *name, + struct lock_class_key *key, int subclass) +{ + cross_init(lock, 0); + __lockdep_init_map(lock, name, key, subclass); +} EXPORT_SYMBOL_GPL(lockdep_init_map); +#ifdef CONFIG_LOCKDEP_CROSSRELEASE +void lockdep_init_map_crosslock(struct lockdep_map *lock, const char *name, + struct lock_class_key *key, int subclass) +{ + cross_init(lock, 1); + __lockdep_init_map(lock, name, key, subclass); +} +EXPORT_SYMBOL_GPL(lockdep_init_map_crosslock); +#endif + struct lock_class_key __lockdep_no_validate__; EXPORT_SYMBOL_GPL(__lockdep_no_validate__); @@ -3291,6 +3333,7 @@ static int __lock_acquire(struct lockdep_map *lock, unsigned int subclass, int chain_head = 0; int class_idx; u64 chain_key; + int ret; if (unlikely(!debug_locks)) return 0; @@ -3339,7 +3382,8 @@ static int __lock_acquire(struct lockdep_map *lock, unsigned int subclass, class_idx = class - lock_classes + 1; - if (depth) { + /* TODO: nest_lock is not implemented for crosslock yet. */ + if (depth && !cross_lock(lock)) { hlock = curr->held_locks + depth - 1; if (hlock->class_idx == class_idx && nest_lock) { if (hlock->references) { @@ -3427,6 +3471,14 @@ static int __lock_acquire(struct lockdep_map *lock, unsigned int subclass, if (!validate_chain(curr, lock, hlock, chain_head, chain_key)) return 0; + ret = lock_acquire_crosslock(hlock); + /* + * 2 means normal acquire operations are needed. Otherwise, it's + * ok just to return with '0:fail, 1:success'. + */ + if (ret != 2) + return ret; + curr->curr_chain_key = chain_key; curr->lockdep_depth++; check_chain_key(curr); @@ -3664,11 +3716,19 @@ __lock_release(struct lockdep_map *lock, int nested, unsigned long ip) struct task_struct *curr = current; struct held_lock *hlock; unsigned int depth; - int i; + int ret, i; if (unlikely(!debug_locks)) return 0; + ret = lock_release_crosslock(lock); + /* + * 2 means normal release operations are needed. Otherwise, it's + * ok just to return with '0:fail, 1:success'. + */ + if (ret != 2) + return ret; + depth = curr->lockdep_depth; /* * So we're all set to release this lock.. wait what lock? We don't @@ -4532,6 +4592,13 @@ asmlinkage __visible void lockdep_sys_exit(void) curr->comm, curr->pid); lockdep_print_held_locks(curr); } + + /* + * The lock history for each syscall should be independent. So wipe the + * slate clean on return to userspace. + */ + crossrelease_hist_end(XHLOCK_PROC); + crossrelease_hist_start(XHLOCK_PROC); } void lockdep_rcu_suspicious(const char *file, const int line, const char *s) @@ -4580,3 +4647,398 @@ void lockdep_rcu_suspicious(const char *file, const int line, const char *s) dump_stack(); } EXPORT_SYMBOL_GPL(lockdep_rcu_suspicious); + +#ifdef CONFIG_LOCKDEP_CROSSRELEASE + +/* + * Crossrelease works by recording a lock history for each thread and + * connecting those historic locks that were taken after the + * wait_for_completion() in the complete() context. + * + * Task-A Task-B + * + * mutex_lock(&A); + * mutex_unlock(&A); + * + * wait_for_completion(&C); + * lock_acquire_crosslock(); + * atomic_inc_return(&cross_gen_id); + * | + * | mutex_lock(&B); + * | mutex_unlock(&B); + * | + * | complete(&C); + * `-- lock_commit_crosslock(); + * + * Which will then add a dependency between B and C. + */ + +#define xhlock(i) (current->xhlocks[(i) % MAX_XHLOCKS_NR]) + +/* + * Whenever a crosslock is held, cross_gen_id will be increased. + */ +static atomic_t cross_gen_id; /* Can be wrapped */ + +/* + * Lock history stacks; we have 3 nested lock history stacks: + * + * Hard IRQ + * Soft IRQ + * History / Task + * + * The thing is that once we complete a (Hard/Soft) IRQ the future task locks + * should not depend on any of the locks observed while running the IRQ. + * + * So what we do is rewind the history buffer and erase all our knowledge of + * that temporal event. + */ + +/* + * We need this to annotate lock history boundaries. Take for instance + * workqueues; each work is independent of the last. The completion of a future + * work does not depend on the completion of a past work (in general). + * Therefore we must not carry that (lock) dependency across works. + * + * This is true for many things; pretty much all kthreads fall into this + * pattern, where they have an 'idle' state and future completions do not + * depend on past completions. Its just that since they all have the 'same' + * form -- the kthread does the same over and over -- it doesn't typically + * matter. + * + * The same is true for system-calls, once a system call is completed (we've + * returned to userspace) the next system call does not depend on the lock + * history of the previous system call. + */ +void crossrelease_hist_start(enum xhlock_context_t c) +{ + if (current->xhlocks) + current->xhlock_idx_hist[c] = current->xhlock_idx; +} + +void crossrelease_hist_end(enum xhlock_context_t c) +{ + if (current->xhlocks) + current->xhlock_idx = current->xhlock_idx_hist[c]; +} + +static int cross_lock(struct lockdep_map *lock) +{ + return lock ? lock->cross : 0; +} + +/* + * This is needed to decide the relationship between wrapable variables. + */ +static inline int before(unsigned int a, unsigned int b) +{ + return (int)(a - b) < 0; +} + +static inline struct lock_class *xhlock_class(struct hist_lock *xhlock) +{ + return hlock_class(&xhlock->hlock); +} + +static inline struct lock_class *xlock_class(struct cross_lock *xlock) +{ + return hlock_class(&xlock->hlock); +} + +/* + * Should we check a dependency with previous one? + */ +static inline int depend_before(struct held_lock *hlock) +{ + return hlock->read != 2 && hlock->check && !hlock->trylock; +} + +/* + * Should we check a dependency with next one? + */ +static inline int depend_after(struct held_lock *hlock) +{ + return hlock->read != 2 && hlock->check; +} + +/* + * Check if the xhlock is valid, which would be false if, + * + * 1. Has not used after initializaion yet. + * + * Remind hist_lock is implemented as a ring buffer. + */ +static inline int xhlock_valid(struct hist_lock *xhlock) +{ + /* + * xhlock->hlock.instance must be !NULL. + */ + return !!xhlock->hlock.instance; +} + +/* + * Record a hist_lock entry. + * + * Irq disable is only required. + */ +static void add_xhlock(struct held_lock *hlock) +{ + unsigned int idx = ++current->xhlock_idx; + struct hist_lock *xhlock = &xhlock(idx); + +#ifdef CONFIG_DEBUG_LOCKDEP + /* + * This can be done locklessly because they are all task-local + * state, we must however ensure IRQs are disabled. + */ + WARN_ON_ONCE(!irqs_disabled()); +#endif + + /* Initialize hist_lock's members */ + xhlock->hlock = *hlock; + + xhlock->trace.nr_entries = 0; + xhlock->trace.max_entries = MAX_XHLOCK_TRACE_ENTRIES; + xhlock->trace.entries = xhlock->trace_entries; + xhlock->trace.skip = 3; + save_stack_trace(&xhlock->trace); +} + +static inline int same_context_xhlock(struct hist_lock *xhlock) +{ + return xhlock->hlock.irq_context == task_irq_context(current); +} + +/* + * This should be lockless as far as possible because this would be + * called very frequently. + */ +static void check_add_xhlock(struct held_lock *hlock) +{ + /* + * Record a hist_lock, only in case that acquisitions ahead + * could depend on the held_lock. For example, if the held_lock + * is trylock then acquisitions ahead never depends on that. + * In that case, we don't need to record it. Just return. + */ + if (!current->xhlocks || !depend_before(hlock)) + return; + + add_xhlock(hlock); +} + +/* + * For crosslock. + */ +static int add_xlock(struct held_lock *hlock) +{ + struct cross_lock *xlock; + unsigned int gen_id; + + if (!graph_lock()) + return 0; + + xlock = &((struct lockdep_map_cross *)hlock->instance)->xlock; + + gen_id = (unsigned int)atomic_inc_return(&cross_gen_id); + xlock->hlock = *hlock; + xlock->hlock.gen_id = gen_id; + graph_unlock(); + + return 1; +} + +/* + * Called for both normal and crosslock acquires. Normal locks will be + * pushed on the hist_lock queue. Cross locks will record state and + * stop regular lock_acquire() to avoid being placed on the held_lock + * stack. + * + * Return: 0 - failure; + * 1 - crosslock, done; + * 2 - normal lock, continue to held_lock[] ops. + */ +static int lock_acquire_crosslock(struct held_lock *hlock) +{ + /* + * CONTEXT 1 CONTEXT 2 + * --------- --------- + * lock A (cross) + * X = atomic_inc_return(&cross_gen_id) + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Y = atomic_read_acquire(&cross_gen_id) + * lock B + * + * atomic_read_acquire() is for ordering between A and B, + * IOW, A happens before B, when CONTEXT 2 see Y >= X. + * + * Pairs with atomic_inc_return() in add_xlock(). + */ + hlock->gen_id = (unsigned int)atomic_read_acquire(&cross_gen_id); + + if (cross_lock(hlock->instance)) + return add_xlock(hlock); + + check_add_xhlock(hlock); + return 2; +} + +static int copy_trace(struct stack_trace *trace) +{ + unsigned long *buf = stack_trace + nr_stack_trace_entries; + unsigned int max_nr = MAX_STACK_TRACE_ENTRIES - nr_stack_trace_entries; + unsigned int nr = min(max_nr, trace->nr_entries); + + trace->nr_entries = nr; + memcpy(buf, trace->entries, nr * sizeof(trace->entries[0])); + trace->entries = buf; + nr_stack_trace_entries += nr; + + if (nr_stack_trace_entries >= MAX_STACK_TRACE_ENTRIES-1) { + if (!debug_locks_off_graph_unlock()) + return 0; + + print_lockdep_off("BUG: MAX_STACK_TRACE_ENTRIES too low!"); + dump_stack(); + + return 0; + } + + return 1; +} + +static int commit_xhlock(struct cross_lock *xlock, struct hist_lock *xhlock) +{ + unsigned int xid, pid; + u64 chain_key; + + xid = xlock_class(xlock) - lock_classes; + chain_key = iterate_chain_key((u64)0, xid); + pid = xhlock_class(xhlock) - lock_classes; + chain_key = iterate_chain_key(chain_key, pid); + + if (lookup_chain_cache(chain_key)) + return 1; + + if (!add_chain_cache_classes(xid, pid, xhlock->hlock.irq_context, + chain_key)) + return 0; + + if (!check_prev_add(current, &xlock->hlock, &xhlock->hlock, 1, + &xhlock->trace, copy_trace)) + return 0; + + return 1; +} + +static void commit_xhlocks(struct cross_lock *xlock) +{ + unsigned int cur = current->xhlock_idx; + unsigned int i; + + if (!graph_lock()) + return; + + for (i = 0; i < MAX_XHLOCKS_NR; i++) { + struct hist_lock *xhlock = &xhlock(cur - i); + + if (!xhlock_valid(xhlock)) + break; + + if (before(xhlock->hlock.gen_id, xlock->hlock.gen_id)) + break; + + if (!same_context_xhlock(xhlock)) + break; + + /* + * commit_xhlock() returns 0 with graph_lock already + * released if fail. + */ + if (!commit_xhlock(xlock, xhlock)) + return; + } + + graph_unlock(); +} + +void lock_commit_crosslock(struct lockdep_map *lock) +{ + struct cross_lock *xlock; + unsigned long flags; + + if (unlikely(!debug_locks || current->lockdep_recursion)) + return; + + if (!current->xhlocks) + return; + + /* + * Do commit hist_locks with the cross_lock, only in case that + * the cross_lock could depend on acquisitions after that. + * + * For example, if the cross_lock does not have the 'check' flag + * then we don't need to check dependencies and commit for that. + * Just skip it. In that case, of course, the cross_lock does + * not depend on acquisitions ahead, either. + * + * WARNING: Don't do that in add_xlock() in advance. When an + * acquisition context is different from the commit context, + * invalid(skipped) cross_lock might be accessed. + */ + if (!depend_after(&((struct lockdep_map_cross *)lock)->xlock.hlock)) + return; + + raw_local_irq_save(flags); + check_flags(flags); + current->lockdep_recursion = 1; + xlock = &((struct lockdep_map_cross *)lock)->xlock; + commit_xhlocks(xlock); + current->lockdep_recursion = 0; + raw_local_irq_restore(flags); +} +EXPORT_SYMBOL_GPL(lock_commit_crosslock); + +/* + * Return: 1 - crosslock, done; + * 2 - normal lock, continue to held_lock[] ops. + */ +static int lock_release_crosslock(struct lockdep_map *lock) +{ + return cross_lock(lock) ? 1 : 2; +} + +static void cross_init(struct lockdep_map *lock, int cross) +{ + lock->cross = cross; + + /* + * Crossrelease assumes that the ring buffer size of xhlocks + * is aligned with power of 2. So force it on build. + */ + BUILD_BUG_ON(MAX_XHLOCKS_NR & (MAX_XHLOCKS_NR - 1)); +} + +void lockdep_init_task(struct task_struct *task) +{ + int i; + + task->xhlock_idx = UINT_MAX; + + for (i = 0; i < XHLOCK_CTX_NR; i++) + task->xhlock_idx_hist[i] = UINT_MAX; + + task->xhlocks = kzalloc(sizeof(struct hist_lock) * MAX_XHLOCKS_NR, + GFP_KERNEL); +} + +void lockdep_free_task(struct task_struct *task) +{ + if (task->xhlocks) { + void *tmp = task->xhlocks; + /* Diable crossrelease for current */ + task->xhlocks = NULL; + kfree(tmp); + } +} +#endif diff --git a/kernel/workqueue.c b/kernel/workqueue.c index ca937b0c3a96..e86733a8b344 100644 --- a/kernel/workqueue.c +++ b/kernel/workqueue.c @@ -2093,6 +2093,7 @@ __acquires(&pool->lock) lock_map_acquire_read(&pwq->wq->lockdep_map); lock_map_acquire(&lockdep_map); + crossrelease_hist_start(XHLOCK_PROC); trace_workqueue_execute_start(work); worker->current_func(work); /* @@ -2100,6 +2101,7 @@ __acquires(&pool->lock) * point will only record its address. */ trace_workqueue_execute_end(work); + crossrelease_hist_end(XHLOCK_PROC); lock_map_release(&lockdep_map); lock_map_release(&pwq->wq->lockdep_map); diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 98fe715522e8..c6038f23bb1a 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1150,6 +1150,18 @@ config LOCK_STAT CONFIG_LOCK_STAT defines "contended" and "acquired" lock events. (CONFIG_LOCKDEP defines "acquire" and "release" events.) +config LOCKDEP_CROSSRELEASE + bool "Lock debugging: make lockdep work for crosslocks" + depends on PROVE_LOCKING + default n + help + This makes lockdep work for crosslock which is a lock allowed to + be released in a different context from the acquisition context. + Normally a lock must be released in the context acquiring the lock. + However, relexing this constraint helps synchronization primitives + such as page locks or completions can use the lock correctness + detector, lockdep. + config DEBUG_LOCKDEP bool "Lock dependency engine debugging" depends on DEBUG_KERNEL && LOCKDEP -- cgit v1.2.3 From cd8084f91c02c1afd256a39aa833bff737631304 Mon Sep 17 00:00:00 2001 From: Byungchul Park Date: Mon, 7 Aug 2017 16:12:56 +0900 Subject: locking/lockdep: Apply crossrelease to completions Although wait_for_completion() and its family can cause deadlock, the lock correctness validator could not be applied to them until now, because things like complete() are usually called in a different context from the waiting context, which violates lockdep's assumption. Thanks to CONFIG_LOCKDEP_CROSSRELEASE, we can now apply the lockdep detector to those completion operations. Applied it. Signed-off-by: Byungchul Park Signed-off-by: Peter Zijlstra (Intel) Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: akpm@linux-foundation.org Cc: boqun.feng@gmail.com Cc: kernel-team@lge.com Cc: kirill@shutemov.name Cc: npiggin@gmail.com Cc: walken@google.com Cc: willy@infradead.org Link: http://lkml.kernel.org/r/1502089981-21272-10-git-send-email-byungchul.park@lge.com Signed-off-by: Ingo Molnar --- include/linux/completion.h | 45 ++++++++++++++++++++++++++++++++++++++++++++- kernel/sched/completion.c | 11 +++++++++++ lib/Kconfig.debug | 9 +++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/include/linux/completion.h b/include/linux/completion.h index 5d5aaae3af43..9bcebf509b4d 100644 --- a/include/linux/completion.h +++ b/include/linux/completion.h @@ -9,6 +9,9 @@ */ #include +#ifdef CONFIG_LOCKDEP_COMPLETE +#include +#endif /* * struct completion - structure used to maintain state for a "completion" @@ -25,10 +28,50 @@ struct completion { unsigned int done; wait_queue_head_t wait; +#ifdef CONFIG_LOCKDEP_COMPLETE + struct lockdep_map_cross map; +#endif }; +#ifdef CONFIG_LOCKDEP_COMPLETE +static inline void complete_acquire(struct completion *x) +{ + lock_acquire_exclusive((struct lockdep_map *)&x->map, 0, 0, NULL, _RET_IP_); +} + +static inline void complete_release(struct completion *x) +{ + lock_release((struct lockdep_map *)&x->map, 0, _RET_IP_); +} + +static inline void complete_release_commit(struct completion *x) +{ + lock_commit_crosslock((struct lockdep_map *)&x->map); +} + +#define init_completion(x) \ +do { \ + static struct lock_class_key __key; \ + lockdep_init_map_crosslock((struct lockdep_map *)&(x)->map, \ + "(complete)" #x, \ + &__key, 0); \ + __init_completion(x); \ +} while (0) +#else +#define init_completion(x) __init_completion(x) +static inline void complete_acquire(struct completion *x) {} +static inline void complete_release(struct completion *x) {} +static inline void complete_release_commit(struct completion *x) {} +#endif + +#ifdef CONFIG_LOCKDEP_COMPLETE +#define COMPLETION_INITIALIZER(work) \ + { 0, __WAIT_QUEUE_HEAD_INITIALIZER((work).wait), \ + STATIC_CROSS_LOCKDEP_MAP_INIT("(complete)" #work, &(work)) } +#else #define COMPLETION_INITIALIZER(work) \ { 0, __WAIT_QUEUE_HEAD_INITIALIZER((work).wait) } +#endif #define COMPLETION_INITIALIZER_ONSTACK(work) \ ({ init_completion(&work); work; }) @@ -70,7 +113,7 @@ struct completion { * This inline function will initialize a dynamically created completion * structure. */ -static inline void init_completion(struct completion *x) +static inline void __init_completion(struct completion *x) { x->done = 0; init_waitqueue_head(&x->wait); diff --git a/kernel/sched/completion.c b/kernel/sched/completion.c index 13fc5ae9bf2f..566b6ec7b6fe 100644 --- a/kernel/sched/completion.c +++ b/kernel/sched/completion.c @@ -32,6 +32,12 @@ void complete(struct completion *x) unsigned long flags; spin_lock_irqsave(&x->wait.lock, flags); + + /* + * Perform commit of crossrelease here. + */ + complete_release_commit(x); + if (x->done != UINT_MAX) x->done++; __wake_up_locked(&x->wait, TASK_NORMAL, 1); @@ -92,9 +98,14 @@ __wait_for_common(struct completion *x, { might_sleep(); + complete_acquire(x); + spin_lock_irq(&x->wait.lock); timeout = do_wait_for_common(x, action, timeout, state); spin_unlock_irq(&x->wait.lock); + + complete_release(x); + return timeout; } diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index c6038f23bb1a..ebd40d3b0b28 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1162,6 +1162,15 @@ config LOCKDEP_CROSSRELEASE such as page locks or completions can use the lock correctness detector, lockdep. +config LOCKDEP_COMPLETE + bool "Lock debugging: allow completions to use deadlock detector" + depends on PROVE_LOCKING + select LOCKDEP_CROSSRELEASE + default n + help + A deadlock caused by wait_for_completion() and complete() can be + detected by lockdep using crossrelease feature. + config DEBUG_LOCKDEP bool "Lock dependency engine debugging" depends on DEBUG_KERNEL && LOCKDEP -- cgit v1.2.3 From c92316bf8e94830a0225f2e904cbdbd173768419 Mon Sep 17 00:00:00 2001 From: "Luis R. Rodriguez" Date: Thu, 20 Jul 2017 13:13:42 -0700 Subject: test_firmware: add batched firmware tests The firmware API has a feature to enable batching requests for the same fil e under one worker, so only one lookup is done. This only triggers if we so happen to schedule two lookups for same file around the same time, or if release_firmware() has not been called for a successful firmware call. This can happen for instance if you happen to have multiple devices and one device driver for certain drivers where the stars line up scheduling wise. This adds a new sync and async test trigger. Instead of adding a new trigger for each new test type we make the tests a bit configurable so that we could configure the tests in userspace and just kick a test through a few basic triggers. With this, for instance the two types of sync requests: o request_firmware() and o request_firmware_direct() can be modified with a knob. Likewise the two type of async requests: o request_firmware_nowait(uevent=true) and o request_firmware_nowait(uevent=false) can be configured with another knob. The call request_firmware_into_buf() has no users... yet. The old tests are left in place as-is given they serve a few other purposes which we are currently not interested in also testing yet. This will change later as we will be able to just consolidate all tests under a few basic triggers with just one general configuration setup. We perform two types of tests, one for where the file is present and one for where the file is not present. All test tests go tested and they now pass for the following 3 kernel builds possible for the firmware API: 0. Most distro setup: CONFIG_FW_LOADER_USER_HELPER_FALLBACK=n CONFIG_FW_LOADER_USER_HELPER=y 1. Android: CONFIG_FW_LOADER_USER_HELPER_FALLBACK=y CONFIG_FW_LOADER_USER_HELPER=y 2. Rare build: CONFIG_FW_LOADER_USER_HELPER_FALLBACK=n CONFIG_FW_LOADER_USER_HELPER=n Signed-off-by: Luis R. Rodriguez Signed-off-by: Greg Kroah-Hartman --- lib/test_firmware.c | 710 ++++++++++++++++++++++ tools/testing/selftests/firmware/fw_filesystem.sh | 241 +++++++- 2 files changed, 949 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/test_firmware.c b/lib/test_firmware.c index 09371b0a9baf..64a4c76cba2b 100644 --- a/lib/test_firmware.c +++ b/lib/test_firmware.c @@ -19,10 +19,85 @@ #include #include #include +#include +#include + +#define TEST_FIRMWARE_NAME "test-firmware.bin" +#define TEST_FIRMWARE_NUM_REQS 4 static DEFINE_MUTEX(test_fw_mutex); static const struct firmware *test_firmware; +struct test_batched_req { + u8 idx; + int rc; + bool sent; + const struct firmware *fw; + const char *name; + struct completion completion; + struct task_struct *task; + struct device *dev; +}; + +/** + * test_config - represents configuration for the test for different triggers + * + * @name: the name of the firmware file to look for + * @sync_direct: when the sync trigger is used if this is true + * request_firmware_direct() will be used instead. + * @send_uevent: whether or not to send a uevent for async requests + * @num_requests: number of requests to try per test case. This is trigger + * specific. + * @reqs: stores all requests information + * @read_fw_idx: index of thread from which we want to read firmware results + * from through the read_fw trigger. + * @test_result: a test may use this to collect the result from the call + * of the request_firmware*() calls used in their tests. In order of + * priority we always keep first any setup error. If no setup errors were + * found then we move on to the first error encountered while running the + * API. Note that for async calls this typically will be a successful + * result (0) unless of course you've used bogus parameters, or the system + * is out of memory. In the async case the callback is expected to do a + * bit more homework to figure out what happened, unfortunately the only + * information passed today on error is the fact that no firmware was + * found so we can only assume -ENOENT on async calls if the firmware is + * NULL. + * + * Errors you can expect: + * + * API specific: + * + * 0: success for sync, for async it means request was sent + * -EINVAL: invalid parameters or request + * -ENOENT: files not found + * + * System environment: + * + * -ENOMEM: memory pressure on system + * -ENODEV: out of number of devices to test + * -EINVAL: an unexpected error has occurred + * @req_firmware: if @sync_direct is true this is set to + * request_firmware_direct(), otherwise request_firmware() + */ +struct test_config { + char *name; + bool sync_direct; + bool send_uevent; + u8 num_requests; + u8 read_fw_idx; + + /* + * These below don't belong her but we'll move them once we create + * a struct fw_test_device and stuff the misc_dev under there later. + */ + struct test_batched_req *reqs; + int test_result; + int (*req_firmware)(const struct firmware **fw, const char *name, + struct device *device); +}; + +struct test_config *test_fw_config; + static ssize_t test_fw_misc_read(struct file *f, char __user *buf, size_t size, loff_t *offset) { @@ -42,6 +117,338 @@ static const struct file_operations test_fw_fops = { .read = test_fw_misc_read, }; +static void __test_release_all_firmware(void) +{ + struct test_batched_req *req; + u8 i; + + if (!test_fw_config->reqs) + return; + + for (i = 0; i < test_fw_config->num_requests; i++) { + req = &test_fw_config->reqs[i]; + if (req->fw) + release_firmware(req->fw); + } + + vfree(test_fw_config->reqs); + test_fw_config->reqs = NULL; +} + +static void test_release_all_firmware(void) +{ + mutex_lock(&test_fw_mutex); + __test_release_all_firmware(); + mutex_unlock(&test_fw_mutex); +} + + +static void __test_firmware_config_free(void) +{ + __test_release_all_firmware(); + kfree_const(test_fw_config->name); + test_fw_config->name = NULL; +} + +/* + * XXX: move to kstrncpy() once merged. + * + * Users should use kfree_const() when freeing these. + */ +static int __kstrncpy(char **dst, const char *name, size_t count, gfp_t gfp) +{ + *dst = kstrndup(name, count, gfp); + if (!*dst) + return -ENOSPC; + return count; +} + +static int __test_firmware_config_init(void) +{ + int ret; + + ret = __kstrncpy(&test_fw_config->name, TEST_FIRMWARE_NAME, + strlen(TEST_FIRMWARE_NAME), GFP_KERNEL); + if (ret < 0) + goto out; + + test_fw_config->num_requests = TEST_FIRMWARE_NUM_REQS; + test_fw_config->send_uevent = true; + test_fw_config->sync_direct = false; + test_fw_config->req_firmware = request_firmware; + test_fw_config->test_result = 0; + test_fw_config->reqs = NULL; + + return 0; + +out: + __test_firmware_config_free(); + return ret; +} + +static ssize_t reset_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + int ret; + + mutex_lock(&test_fw_mutex); + + __test_firmware_config_free(); + + ret = __test_firmware_config_init(); + if (ret < 0) { + ret = -ENOMEM; + pr_err("could not alloc settings for config trigger: %d\n", + ret); + goto out; + } + + pr_info("reset\n"); + ret = count; + +out: + mutex_unlock(&test_fw_mutex); + + return ret; +} +static DEVICE_ATTR_WO(reset); + +static ssize_t config_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + int len = 0; + + mutex_lock(&test_fw_mutex); + + len += snprintf(buf, PAGE_SIZE, + "Custom trigger configuration for: %s\n", + dev_name(dev)); + + if (test_fw_config->name) + len += snprintf(buf+len, PAGE_SIZE, + "name:\t%s\n", + test_fw_config->name); + else + len += snprintf(buf+len, PAGE_SIZE, + "name:\tEMTPY\n"); + + len += snprintf(buf+len, PAGE_SIZE, + "num_requests:\t%u\n", test_fw_config->num_requests); + + len += snprintf(buf+len, PAGE_SIZE, + "send_uevent:\t\t%s\n", + test_fw_config->send_uevent ? + "FW_ACTION_HOTPLUG" : + "FW_ACTION_NOHOTPLUG"); + len += snprintf(buf+len, PAGE_SIZE, + "sync_direct:\t\t%s\n", + test_fw_config->sync_direct ? "true" : "false"); + len += snprintf(buf+len, PAGE_SIZE, + "read_fw_idx:\t%u\n", test_fw_config->read_fw_idx); + + mutex_unlock(&test_fw_mutex); + + return len; +} +static DEVICE_ATTR_RO(config); + +static ssize_t config_name_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + int ret; + + mutex_lock(&test_fw_mutex); + kfree_const(test_fw_config->name); + ret = __kstrncpy(&test_fw_config->name, buf, count, GFP_KERNEL); + mutex_unlock(&test_fw_mutex); + + return ret; +} + +/* + * As per sysfs_kf_seq_show() the buf is max PAGE_SIZE. + */ +static ssize_t config_test_show_str(char *dst, + char *src) +{ + int len; + + mutex_lock(&test_fw_mutex); + len = snprintf(dst, PAGE_SIZE, "%s\n", src); + mutex_unlock(&test_fw_mutex); + + return len; +} + +static int test_dev_config_update_bool(const char *buf, size_t size, + bool *cfg) +{ + int ret; + + mutex_lock(&test_fw_mutex); + if (strtobool(buf, cfg) < 0) + ret = -EINVAL; + else + ret = size; + mutex_unlock(&test_fw_mutex); + + return ret; +} + +static ssize_t +test_dev_config_show_bool(char *buf, + bool config) +{ + bool val; + + mutex_lock(&test_fw_mutex); + val = config; + mutex_unlock(&test_fw_mutex); + + return snprintf(buf, PAGE_SIZE, "%d\n", val); +} + +static ssize_t test_dev_config_show_int(char *buf, int cfg) +{ + int val; + + mutex_lock(&test_fw_mutex); + val = cfg; + mutex_unlock(&test_fw_mutex); + + return snprintf(buf, PAGE_SIZE, "%d\n", val); +} + +static int test_dev_config_update_u8(const char *buf, size_t size, u8 *cfg) +{ + int ret; + long new; + + ret = kstrtol(buf, 10, &new); + if (ret) + return ret; + + if (new > U8_MAX) + return -EINVAL; + + mutex_lock(&test_fw_mutex); + *(u8 *)cfg = new; + mutex_unlock(&test_fw_mutex); + + /* Always return full write size even if we didn't consume all */ + return size; +} + +static ssize_t test_dev_config_show_u8(char *buf, u8 cfg) +{ + u8 val; + + mutex_lock(&test_fw_mutex); + val = cfg; + mutex_unlock(&test_fw_mutex); + + return snprintf(buf, PAGE_SIZE, "%u\n", val); +} + +static ssize_t config_name_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return config_test_show_str(buf, test_fw_config->name); +} +static DEVICE_ATTR(config_name, 0644, config_name_show, config_name_store); + +static ssize_t config_num_requests_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + int rc; + + mutex_lock(&test_fw_mutex); + if (test_fw_config->reqs) { + pr_err("Must call release_all_firmware prior to changing config\n"); + rc = -EINVAL; + goto out; + } + mutex_unlock(&test_fw_mutex); + + rc = test_dev_config_update_u8(buf, count, + &test_fw_config->num_requests); + +out: + return rc; +} + +static ssize_t config_num_requests_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return test_dev_config_show_u8(buf, test_fw_config->num_requests); +} +static DEVICE_ATTR(config_num_requests, 0644, config_num_requests_show, + config_num_requests_store); + +static ssize_t config_sync_direct_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + int rc = test_dev_config_update_bool(buf, count, + &test_fw_config->sync_direct); + + if (rc == count) + test_fw_config->req_firmware = test_fw_config->sync_direct ? + request_firmware_direct : + request_firmware; + return rc; +} + +static ssize_t config_sync_direct_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return test_dev_config_show_bool(buf, test_fw_config->sync_direct); +} +static DEVICE_ATTR(config_sync_direct, 0644, config_sync_direct_show, + config_sync_direct_store); + +static ssize_t config_send_uevent_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + return test_dev_config_update_bool(buf, count, + &test_fw_config->send_uevent); +} + +static ssize_t config_send_uevent_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return test_dev_config_show_bool(buf, test_fw_config->send_uevent); +} +static DEVICE_ATTR(config_send_uevent, 0644, config_send_uevent_show, + config_send_uevent_store); + +static ssize_t config_read_fw_idx_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + return test_dev_config_update_u8(buf, count, + &test_fw_config->read_fw_idx); +} + +static ssize_t config_read_fw_idx_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return test_dev_config_show_u8(buf, test_fw_config->read_fw_idx); +} +static DEVICE_ATTR(config_read_fw_idx, 0644, config_read_fw_idx_show, + config_read_fw_idx_store); + + static ssize_t trigger_request_store(struct device *dev, struct device_attribute *attr, const char *buf, size_t count) @@ -170,12 +577,301 @@ out: } static DEVICE_ATTR_WO(trigger_custom_fallback); +static int test_fw_run_batch_request(void *data) +{ + struct test_batched_req *req = data; + + if (!req) { + test_fw_config->test_result = -EINVAL; + return -EINVAL; + } + + req->rc = test_fw_config->req_firmware(&req->fw, req->name, req->dev); + if (req->rc) { + pr_info("#%u: batched sync load failed: %d\n", + req->idx, req->rc); + if (!test_fw_config->test_result) + test_fw_config->test_result = req->rc; + } else if (req->fw) { + req->sent = true; + pr_info("#%u: batched sync loaded %zu\n", + req->idx, req->fw->size); + } + complete(&req->completion); + + req->task = NULL; + + return 0; +} + +/* + * We use a kthread as otherwise the kernel serializes all our sync requests + * and we would not be able to mimic batched requests on a sync call. Batched + * requests on a sync call can for instance happen on a device driver when + * multiple cards are used and firmware loading happens outside of probe. + */ +static ssize_t trigger_batched_requests_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + struct test_batched_req *req; + int rc; + u8 i; + + mutex_lock(&test_fw_mutex); + + test_fw_config->reqs = vzalloc(sizeof(struct test_batched_req) * + test_fw_config->num_requests * 2); + if (!test_fw_config->reqs) { + rc = -ENOMEM; + goto out_unlock; + } + + pr_info("batched sync firmware loading '%s' %u times\n", + test_fw_config->name, test_fw_config->num_requests); + + for (i = 0; i < test_fw_config->num_requests; i++) { + req = &test_fw_config->reqs[i]; + if (!req) { + WARN_ON(1); + rc = -ENOMEM; + goto out_bail; + } + req->fw = NULL; + req->idx = i; + req->name = test_fw_config->name; + req->dev = dev; + init_completion(&req->completion); + req->task = kthread_run(test_fw_run_batch_request, req, + "%s-%u", KBUILD_MODNAME, req->idx); + if (!req->task || IS_ERR(req->task)) { + pr_err("Setting up thread %u failed\n", req->idx); + req->task = NULL; + rc = -ENOMEM; + goto out_bail; + } + } + + rc = count; + + /* + * We require an explicit release to enable more time and delay of + * calling release_firmware() to improve our chances of forcing a + * batched request. If we instead called release_firmware() right away + * then we might miss on an opportunity of having a successful firmware + * request pass on the opportunity to be come a batched request. + */ + +out_bail: + for (i = 0; i < test_fw_config->num_requests; i++) { + req = &test_fw_config->reqs[i]; + if (req->task || req->sent) + wait_for_completion(&req->completion); + } + + /* Override any worker error if we had a general setup error */ + if (rc < 0) + test_fw_config->test_result = rc; + +out_unlock: + mutex_unlock(&test_fw_mutex); + + return rc; +} +static DEVICE_ATTR_WO(trigger_batched_requests); + +/* + * We wait for each callback to return with the lock held, no need to lock here + */ +static void trigger_batched_cb(const struct firmware *fw, void *context) +{ + struct test_batched_req *req = context; + + if (!req) { + test_fw_config->test_result = -EINVAL; + return; + } + + /* forces *some* batched requests to queue up */ + if (!req->idx) + ssleep(2); + + req->fw = fw; + + /* + * Unfortunately the firmware API gives us nothing other than a null FW + * if the firmware was not found on async requests. Best we can do is + * just assume -ENOENT. A better API would pass the actual return + * value to the callback. + */ + if (!fw && !test_fw_config->test_result) + test_fw_config->test_result = -ENOENT; + + complete(&req->completion); +} + +static +ssize_t trigger_batched_requests_async_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + struct test_batched_req *req; + bool send_uevent; + int rc; + u8 i; + + mutex_lock(&test_fw_mutex); + + test_fw_config->reqs = vzalloc(sizeof(struct test_batched_req) * + test_fw_config->num_requests * 2); + if (!test_fw_config->reqs) { + rc = -ENOMEM; + goto out; + } + + pr_info("batched loading '%s' custom fallback mechanism %u times\n", + test_fw_config->name, test_fw_config->num_requests); + + send_uevent = test_fw_config->send_uevent ? FW_ACTION_HOTPLUG : + FW_ACTION_NOHOTPLUG; + + for (i = 0; i < test_fw_config->num_requests; i++) { + req = &test_fw_config->reqs[i]; + if (!req) { + WARN_ON(1); + goto out_bail; + } + req->name = test_fw_config->name; + req->fw = NULL; + req->idx = i; + init_completion(&req->completion); + rc = request_firmware_nowait(THIS_MODULE, send_uevent, + req->name, + dev, GFP_KERNEL, req, + trigger_batched_cb); + if (rc) { + pr_info("#%u: batched async load failed setup: %d\n", + i, rc); + req->rc = rc; + goto out_bail; + } else + req->sent = true; + } + + rc = count; + +out_bail: + + /* + * We require an explicit release to enable more time and delay of + * calling release_firmware() to improve our chances of forcing a + * batched request. If we instead called release_firmware() right away + * then we might miss on an opportunity of having a successful firmware + * request pass on the opportunity to be come a batched request. + */ + + for (i = 0; i < test_fw_config->num_requests; i++) { + req = &test_fw_config->reqs[i]; + if (req->sent) + wait_for_completion(&req->completion); + } + + /* Override any worker error if we had a general setup error */ + if (rc < 0) + test_fw_config->test_result = rc; + +out: + mutex_unlock(&test_fw_mutex); + + return rc; +} +static DEVICE_ATTR_WO(trigger_batched_requests_async); + +static ssize_t test_result_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + return test_dev_config_show_int(buf, test_fw_config->test_result); +} +static DEVICE_ATTR_RO(test_result); + +static ssize_t release_all_firmware_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + test_release_all_firmware(); + return count; +} +static DEVICE_ATTR_WO(release_all_firmware); + +static ssize_t read_firmware_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct test_batched_req *req; + u8 idx; + ssize_t rc = 0; + + mutex_lock(&test_fw_mutex); + + idx = test_fw_config->read_fw_idx; + if (idx >= test_fw_config->num_requests) { + rc = -ERANGE; + goto out; + } + + if (!test_fw_config->reqs) { + rc = -EINVAL; + goto out; + } + + req = &test_fw_config->reqs[idx]; + if (!req->fw) { + pr_err("#%u: failed to async load firmware\n", idx); + rc = -ENOENT; + goto out; + } + + pr_info("#%u: loaded %zu\n", idx, req->fw->size); + + if (req->fw->size > PAGE_SIZE) { + pr_err("Testing interface must use PAGE_SIZE firmware for now\n"); + rc = -EINVAL; + } + memcpy(buf, req->fw->data, req->fw->size); + + rc = req->fw->size; +out: + mutex_unlock(&test_fw_mutex); + + return rc; +} +static DEVICE_ATTR_RO(read_firmware); + #define TEST_FW_DEV_ATTR(name) &dev_attr_##name.attr static struct attribute *test_dev_attrs[] = { + TEST_FW_DEV_ATTR(reset), + + TEST_FW_DEV_ATTR(config), + TEST_FW_DEV_ATTR(config_name), + TEST_FW_DEV_ATTR(config_num_requests), + TEST_FW_DEV_ATTR(config_sync_direct), + TEST_FW_DEV_ATTR(config_send_uevent), + TEST_FW_DEV_ATTR(config_read_fw_idx), + + /* These don't use the config at all - they could be ported! */ TEST_FW_DEV_ATTR(trigger_request), TEST_FW_DEV_ATTR(trigger_async_request), TEST_FW_DEV_ATTR(trigger_custom_fallback), + + /* These use the config and can use the test_result */ + TEST_FW_DEV_ATTR(trigger_batched_requests), + TEST_FW_DEV_ATTR(trigger_batched_requests_async), + + TEST_FW_DEV_ATTR(release_all_firmware), + TEST_FW_DEV_ATTR(test_result), + TEST_FW_DEV_ATTR(read_firmware), NULL, }; @@ -192,8 +888,17 @@ static int __init test_firmware_init(void) { int rc; + test_fw_config = kzalloc(sizeof(struct test_config), GFP_KERNEL); + if (!test_fw_config) + return -ENOMEM; + + rc = __test_firmware_config_init(); + if (rc) + return rc; + rc = misc_register(&test_fw_misc_device); if (rc) { + kfree(test_fw_config); pr_err("could not register misc device: %d\n", rc); return rc; } @@ -207,8 +912,13 @@ module_init(test_firmware_init); static void __exit test_firmware_exit(void) { + mutex_lock(&test_fw_mutex); release_firmware(test_firmware); misc_deregister(&test_fw_misc_device); + __test_firmware_config_free(); + kfree(test_fw_config); + mutex_unlock(&test_fw_mutex); + pr_warn("removed interface\n"); } diff --git a/tools/testing/selftests/firmware/fw_filesystem.sh b/tools/testing/selftests/firmware/fw_filesystem.sh index e35691239350..7d8fd2e3695a 100755 --- a/tools/testing/selftests/firmware/fw_filesystem.sh +++ b/tools/testing/selftests/firmware/fw_filesystem.sh @@ -25,8 +25,9 @@ if [ ! -d $DIR ]; then fi # CONFIG_FW_LOADER_USER_HELPER has a sysfs class under /sys/class/firmware/ -# These days no one enables CONFIG_FW_LOADER_USER_HELPER so check for that -# as an indicator for CONFIG_FW_LOADER_USER_HELPER. +# These days most distros enable CONFIG_FW_LOADER_USER_HELPER but disable +# CONFIG_FW_LOADER_USER_HELPER_FALLBACK. We use /sys/class/firmware/ as an +# indicator for CONFIG_FW_LOADER_USER_HELPER. HAS_FW_LOADER_USER_HELPER=$(if [ -d /sys/class/firmware/ ]; then echo yes; else echo no; fi) if [ "$HAS_FW_LOADER_USER_HELPER" = "yes" ]; then @@ -116,4 +117,240 @@ else echo "$0: async filesystem loading works" fi +### Batched requests tests +test_config_present() +{ + if [ ! -f $DIR/reset ]; then + echo "Configuration triggers not present, ignoring test" + exit 0 + fi +} + +# Defaults : +# +# send_uevent: 1 +# sync_direct: 0 +# name: test-firmware.bin +# num_requests: 4 +config_reset() +{ + echo 1 > $DIR/reset +} + +release_all_firmware() +{ + echo 1 > $DIR/release_all_firmware +} + +config_set_name() +{ + echo -n $1 > $DIR/config_name +} + +config_set_sync_direct() +{ + echo 1 > $DIR/config_sync_direct +} + +config_unset_sync_direct() +{ + echo 0 > $DIR/config_sync_direct +} + +config_set_uevent() +{ + echo 1 > $DIR/config_send_uevent +} + +config_unset_uevent() +{ + echo 0 > $DIR/config_send_uevent +} + +config_trigger_sync() +{ + echo -n 1 > $DIR/trigger_batched_requests 2>/dev/null +} + +config_trigger_async() +{ + echo -n 1 > $DIR/trigger_batched_requests_async 2> /dev/null +} + +config_set_read_fw_idx() +{ + echo -n $1 > $DIR/config_read_fw_idx 2> /dev/null +} + +read_firmwares() +{ + for i in $(seq 0 3); do + config_set_read_fw_idx $i + # Verify the contents are what we expect. + # -Z required for now -- check for yourself, md5sum + # on $FW and DIR/read_firmware will yield the same. Even + # cmp agrees, so something is off. + if ! diff -q -Z "$FW" $DIR/read_firmware 2>/dev/null ; then + echo "request #$i: firmware was not loaded" >&2 + exit 1 + fi + done +} + +read_firmwares_expect_nofile() +{ + for i in $(seq 0 3); do + config_set_read_fw_idx $i + # Ensures contents differ + if diff -q -Z "$FW" $DIR/read_firmware 2>/dev/null ; then + echo "request $i: file was not expected to match" >&2 + exit 1 + fi + done +} + +test_batched_request_firmware_nofile() +{ + echo -n "Batched request_firmware() nofile try #$1: " + config_reset + config_set_name nope-test-firmware.bin + config_trigger_sync + read_firmwares_expect_nofile + release_all_firmware + echo "OK" +} + +test_batched_request_firmware_direct_nofile() +{ + echo -n "Batched request_firmware_direct() nofile try #$1: " + config_reset + config_set_name nope-test-firmware.bin + config_set_sync_direct + config_trigger_sync + release_all_firmware + echo "OK" +} + +test_request_firmware_nowait_uevent_nofile() +{ + echo -n "Batched request_firmware_nowait(uevent=true) nofile try #$1: " + config_reset + config_set_name nope-test-firmware.bin + config_trigger_async + release_all_firmware + echo "OK" +} + +test_wait_and_cancel_custom_load() +{ + if [ "$HAS_FW_LOADER_USER_HELPER" != "yes" ]; then + return + fi + local timeout=10 + name=$1 + while [ ! -e "$DIR"/"$name"/loading ]; do + sleep 0.1 + timeout=$(( $timeout - 1 )) + if [ "$timeout" -eq 0 ]; then + echo "firmware interface never appeared:" >&2 + echo "$DIR/$name/loading" >&2 + exit 1 + fi + done + echo -1 >"$DIR"/"$name"/loading +} + +test_request_firmware_nowait_custom_nofile() +{ + echo -n "Batched request_firmware_nowait(uevent=false) nofile try #$1: " + config_unset_uevent + config_set_name nope-test-firmware.bin + config_trigger_async & + test_wait_and_cancel_custom_load nope-test-firmware.bin + wait + release_all_firmware + echo "OK" +} + +test_batched_request_firmware() +{ + echo -n "Batched request_firmware() try #$1: " + config_reset + config_trigger_sync + read_firmwares + release_all_firmware + echo "OK" +} + +test_batched_request_firmware_direct() +{ + echo -n "Batched request_firmware_direct() try #$1: " + config_reset + config_set_sync_direct + config_trigger_sync + release_all_firmware + echo "OK" +} + +test_request_firmware_nowait_uevent() +{ + echo -n "Batched request_firmware_nowait(uevent=true) try #$1: " + config_reset + config_trigger_async + release_all_firmware + echo "OK" +} + +test_request_firmware_nowait_custom() +{ + echo -n "Batched request_firmware_nowait(uevent=false) try #$1: " + config_unset_uevent + config_trigger_async + release_all_firmware + echo "OK" +} + +# Only continue if batched request triggers are present on the +# test-firmware driver +test_config_present + +# test with the file present +echo +echo "Testing with the file present..." +for i in $(seq 1 5); do + test_batched_request_firmware $i +done + +for i in $(seq 1 5); do + test_batched_request_firmware_direct $i +done + +for i in $(seq 1 5); do + test_request_firmware_nowait_uevent $i +done + +for i in $(seq 1 5); do + test_request_firmware_nowait_custom $i +done + +# Test for file not found, errors are expected, the failure would be +# a hung task, which would require a hard reset. +echo +echo "Testing with the file missing..." +for i in $(seq 1 5); do + test_batched_request_firmware_nofile $i +done + +for i in $(seq 1 5); do + test_batched_request_firmware_direct_nofile $i +done + +for i in $(seq 1 5); do + test_request_firmware_nowait_uevent_nofile $i +done + +for i in $(seq 1 5); do + test_request_firmware_nowait_custom_nofile $i +done + exit 0 -- cgit v1.2.3 From caba4cbbd27d755572730801ac34fe063fc40a32 Mon Sep 17 00:00:00 2001 From: Waiman Long Date: Mon, 14 Aug 2017 09:52:13 -0400 Subject: debugobjects: Make kmemleak ignore debug objects The allocated debug objects are either on the free list or in the hashed bucket lists. So they won't get lost. However if both debug objects and kmemleak are enabled and kmemleak scanning is done while some of the debug objects are transitioning from one list to the others, false negative reporting of memory leaks may happen for those objects. For example, [38687.275678] kmemleak: 12 new suspected memory leaks (see /sys/kernel/debug/kmemleak) unreferenced object 0xffff92e98aabeb68 (size 40): comm "ksmtuned", pid 4344, jiffies 4298403600 (age 906.430s) hex dump (first 32 bytes): 00 00 00 00 00 00 00 00 d0 bc db 92 e9 92 ff ff ................ 01 00 00 00 00 00 00 00 38 36 8a 61 e9 92 ff ff ........86.a.... backtrace: [] kmemleak_alloc+0x4a/0xa0 [] kmem_cache_alloc+0xe9/0x320 [] __debug_object_init+0x3e6/0x400 [] debug_object_activate+0x131/0x210 [] __call_rcu+0x3f/0x400 [] call_rcu_sched+0x1d/0x20 [] put_object+0x2c/0x40 [] __delete_object+0x3c/0x50 [] delete_object_full+0x1d/0x20 [] kmemleak_free+0x32/0x80 [] kmem_cache_free+0x77/0x350 [] unlink_anon_vmas+0x82/0x1e0 [] free_pgtables+0xa1/0x110 [] exit_mmap+0xc1/0x170 [] mmput+0x80/0x150 [] do_exit+0x2a9/0xd20 The references in the debug objects may also hide a real memory leak. As there is no point in having kmemleak to track debug object allocations, kmemleak checking is now disabled for debug objects. Signed-off-by: Waiman Long Signed-off-by: Thomas Gleixner Cc: Andrew Morton Link: http://lkml.kernel.org/r/1502718733-8527-1-git-send-email-longman@redhat.com --- init/main.c | 2 +- lib/debugobjects.c | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/init/main.c b/init/main.c index 052481fbe363..7ec20cf7b111 100644 --- a/init/main.c +++ b/init/main.c @@ -651,8 +651,8 @@ asmlinkage __visible void __init start_kernel(void) } #endif page_ext_init(); - debug_objects_mem_init(); kmemleak_init(); + debug_objects_mem_init(); setup_per_cpu_pageset(); numa_policy_init(); if (late_time_init) diff --git a/lib/debugobjects.c b/lib/debugobjects.c index 17afb0430161..2f5349c6e81a 100644 --- a/lib/debugobjects.c +++ b/lib/debugobjects.c @@ -18,6 +18,7 @@ #include #include #include +#include #define ODEBUG_HASH_BITS 14 #define ODEBUG_HASH_SIZE (1 << ODEBUG_HASH_BITS) @@ -110,6 +111,7 @@ static void fill_pool(void) if (!new) return; + kmemleak_ignore(new); raw_spin_lock_irqsave(&pool_lock, flags); hlist_add_head(&new->node, &obj_pool); debug_objects_allocated++; @@ -1080,6 +1082,7 @@ static int __init debug_objects_replace_static_objects(void) obj = kmem_cache_zalloc(obj_cache, GFP_KERNEL); if (!obj) goto free; + kmemleak_ignore(obj); hlist_add_head(&obj->node, &objects); } -- cgit v1.2.3 From 5d2405227a9eaea48e8cc95756a06d407b11f141 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Fri, 4 Aug 2017 13:19:17 -0700 Subject: lib: Add xxhash module Adds xxhash kernel module with xxh32 and xxh64 hashes. xxhash is an extremely fast non-cryptographic hash algorithm for checksumming. The zstd compression and decompression modules added in the next patch require xxhash. I extracted it out from zstd since it is useful on its own. I copied the code from the upstream XXHash source repository and translated it into kernel style. I ran benchmarks and tests in the kernel and tests in userland. I benchmarked xxhash as a special character device. I ran in four modes, no-op, xxh32, xxh64, and crc32. The no-op mode simply copies the data to kernel space and ignores it. The xxh32, xxh64, and crc32 modes compute hashes on the copied data. I also ran it with four different buffer sizes. The benchmark file is located in the upstream zstd source repository under `contrib/linux-kernel/xxhash_test.c` [1]. I ran the benchmarks on a Ubuntu 14.04 VM with 2 cores and 4 GiB of RAM. The VM is running on a MacBook Pro with a 3.1 GHz Intel Core i7 processor, 16 GB of RAM, and a SSD. I benchmarked using the file `filesystem.squashfs` from `ubuntu-16.10-desktop-amd64.iso`, which is 1,536,217,088 B large. Run the following commands for the benchmark: modprobe xxhash_test mknod xxhash_test c 245 0 time cp filesystem.squashfs xxhash_test The time is reported by the time of the userland `cp`. The GB/s is computed with 1,536,217,008 B / time(buffer size, hash) which includes the time to copy from userland. The Normalized GB/s is computed with 1,536,217,088 B / (time(buffer size, hash) - time(buffer size, none)). | Buffer Size (B) | Hash | Time (s) | GB/s | Adjusted GB/s | |-----------------|-------|----------|------|---------------| | 1024 | none | 0.408 | 3.77 | - | | 1024 | xxh32 | 0.649 | 2.37 | 6.37 | | 1024 | xxh64 | 0.542 | 2.83 | 11.46 | | 1024 | crc32 | 1.290 | 1.19 | 1.74 | | 4096 | none | 0.380 | 4.04 | - | | 4096 | xxh32 | 0.645 | 2.38 | 5.79 | | 4096 | xxh64 | 0.500 | 3.07 | 12.80 | | 4096 | crc32 | 1.168 | 1.32 | 1.95 | | 8192 | none | 0.351 | 4.38 | - | | 8192 | xxh32 | 0.614 | 2.50 | 5.84 | | 8192 | xxh64 | 0.464 | 3.31 | 13.60 | | 8192 | crc32 | 1.163 | 1.32 | 1.89 | | 16384 | none | 0.346 | 4.43 | - | | 16384 | xxh32 | 0.590 | 2.60 | 6.30 | | 16384 | xxh64 | 0.466 | 3.30 | 12.80 | | 16384 | crc32 | 1.183 | 1.30 | 1.84 | Tested in userland using the test-suite in the zstd repo under `contrib/linux-kernel/test/XXHashUserlandTest.cpp` [2] by mocking the kernel functions. A line in each branch of every function in `xxhash.c` was commented out to ensure that the test-suite fails. Additionally tested while testing zstd and with SMHasher [3]. [1] https://phabricator.intern.facebook.com/P57526246 [2] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/test/XXHashUserlandTest.cpp [3] https://github.com/aappleby/smhasher zstd source repository: https://github.com/facebook/zstd XXHash source repository: https://github.com/cyan4973/xxhash Signed-off-by: Nick Terrell Signed-off-by: Chris Mason --- include/linux/xxhash.h | 236 +++++++++++++++++++++++ lib/Kconfig | 3 + lib/Makefile | 1 + lib/xxhash.c | 500 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 740 insertions(+) create mode 100644 include/linux/xxhash.h create mode 100644 lib/xxhash.c (limited to 'lib') diff --git a/include/linux/xxhash.h b/include/linux/xxhash.h new file mode 100644 index 000000000000..9e1f42cb57e9 --- /dev/null +++ b/include/linux/xxhash.h @@ -0,0 +1,236 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Copyright (C) 2012-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at: + * - xxHash homepage: http://cyan4973.github.io/xxHash/ + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/* + * Notice extracted from xxHash homepage: + * + * xxHash is an extremely fast Hash algorithm, running at RAM speed limits. + * It also successfully passes all tests from the SMHasher suite. + * + * Comparison (single thread, Windows Seven 32 bits, using SMHasher on a Core 2 + * Duo @3GHz) + * + * Name Speed Q.Score Author + * xxHash 5.4 GB/s 10 + * CrapWow 3.2 GB/s 2 Andrew + * MumurHash 3a 2.7 GB/s 10 Austin Appleby + * SpookyHash 2.0 GB/s 10 Bob Jenkins + * SBox 1.4 GB/s 9 Bret Mulvey + * Lookup3 1.2 GB/s 9 Bob Jenkins + * SuperFastHash 1.2 GB/s 1 Paul Hsieh + * CityHash64 1.05 GB/s 10 Pike & Alakuijala + * FNV 0.55 GB/s 5 Fowler, Noll, Vo + * CRC32 0.43 GB/s 9 + * MD5-32 0.33 GB/s 10 Ronald L. Rivest + * SHA1-32 0.28 GB/s 10 + * + * Q.Score is a measure of quality of the hash function. + * It depends on successfully passing SMHasher test set. + * 10 is a perfect score. + * + * A 64-bits version, named xxh64 offers much better speed, + * but for 64-bits applications only. + * Name Speed on 64 bits Speed on 32 bits + * xxh64 13.8 GB/s 1.9 GB/s + * xxh32 6.8 GB/s 6.0 GB/s + */ + +#ifndef XXHASH_H +#define XXHASH_H + +#include + +/*-**************************** + * Simple Hash Functions + *****************************/ + +/** + * xxh32() - calculate the 32-bit hash of the input with a given seed. + * + * @input: The data to hash. + * @length: The length of the data to hash. + * @seed: The seed can be used to alter the result predictably. + * + * Speed on Core 2 Duo @ 3 GHz (single thread, SMHasher benchmark) : 5.4 GB/s + * + * Return: The 32-bit hash of the data. + */ +uint32_t xxh32(const void *input, size_t length, uint32_t seed); + +/** + * xxh64() - calculate the 64-bit hash of the input with a given seed. + * + * @input: The data to hash. + * @length: The length of the data to hash. + * @seed: The seed can be used to alter the result predictably. + * + * This function runs 2x faster on 64-bit systems, but slower on 32-bit systems. + * + * Return: The 64-bit hash of the data. + */ +uint64_t xxh64(const void *input, size_t length, uint64_t seed); + +/*-**************************** + * Streaming Hash Functions + *****************************/ + +/* + * These definitions are only meant to allow allocation of XXH state + * statically, on stack, or in a struct for example. + * Do not use members directly. + */ + +/** + * struct xxh32_state - private xxh32 state, do not use members directly + */ +struct xxh32_state { + uint32_t total_len_32; + uint32_t large_len; + uint32_t v1; + uint32_t v2; + uint32_t v3; + uint32_t v4; + uint32_t mem32[4]; + uint32_t memsize; +}; + +/** + * struct xxh32_state - private xxh64 state, do not use members directly + */ +struct xxh64_state { + uint64_t total_len; + uint64_t v1; + uint64_t v2; + uint64_t v3; + uint64_t v4; + uint64_t mem64[4]; + uint32_t memsize; +}; + +/** + * xxh32_reset() - reset the xxh32 state to start a new hashing operation + * + * @state: The xxh32 state to reset. + * @seed: Initialize the hash state with this seed. + * + * Call this function on any xxh32_state to prepare for a new hashing operation. + */ +void xxh32_reset(struct xxh32_state *state, uint32_t seed); + +/** + * xxh32_update() - hash the data given and update the xxh32 state + * + * @state: The xxh32 state to update. + * @input: The data to hash. + * @length: The length of the data to hash. + * + * After calling xxh32_reset() call xxh32_update() as many times as necessary. + * + * Return: Zero on success, otherwise an error code. + */ +int xxh32_update(struct xxh32_state *state, const void *input, size_t length); + +/** + * xxh32_digest() - produce the current xxh32 hash + * + * @state: Produce the current xxh32 hash of this state. + * + * A hash value can be produced at any time. It is still possible to continue + * inserting input into the hash state after a call to xxh32_digest(), and + * generate new hashes later on, by calling xxh32_digest() again. + * + * Return: The xxh32 hash stored in the state. + */ +uint32_t xxh32_digest(const struct xxh32_state *state); + +/** + * xxh64_reset() - reset the xxh64 state to start a new hashing operation + * + * @state: The xxh64 state to reset. + * @seed: Initialize the hash state with this seed. + */ +void xxh64_reset(struct xxh64_state *state, uint64_t seed); + +/** + * xxh64_update() - hash the data given and update the xxh64 state + * @state: The xxh64 state to update. + * @input: The data to hash. + * @length: The length of the data to hash. + * + * After calling xxh64_reset() call xxh64_update() as many times as necessary. + * + * Return: Zero on success, otherwise an error code. + */ +int xxh64_update(struct xxh64_state *state, const void *input, size_t length); + +/** + * xxh64_digest() - produce the current xxh64 hash + * + * @state: Produce the current xxh64 hash of this state. + * + * A hash value can be produced at any time. It is still possible to continue + * inserting input into the hash state after a call to xxh64_digest(), and + * generate new hashes later on, by calling xxh64_digest() again. + * + * Return: The xxh64 hash stored in the state. + */ +uint64_t xxh64_digest(const struct xxh64_state *state); + +/*-************************** + * Utils + ***************************/ + +/** + * xxh32_copy_state() - copy the source state into the destination state + * + * @src: The source xxh32 state. + * @dst: The destination xxh32 state. + */ +void xxh32_copy_state(struct xxh32_state *dst, const struct xxh32_state *src); + +/** + * xxh64_copy_state() - copy the source state into the destination state + * + * @src: The source xxh64 state. + * @dst: The destination xxh64 state. + */ +void xxh64_copy_state(struct xxh64_state *dst, const struct xxh64_state *src); + +#endif /* XXHASH_H */ diff --git a/lib/Kconfig b/lib/Kconfig index 6762529ad9e4..5e7541f46081 100644 --- a/lib/Kconfig +++ b/lib/Kconfig @@ -192,6 +192,9 @@ config CRC8 when they need to do cyclic redundancy check according CRC8 algorithm. Module will be called crc8. +config XXHASH + tristate + config AUDIT_GENERIC bool depends on AUDIT && !AUDIT_ARCH diff --git a/lib/Makefile b/lib/Makefile index 40c18372b301..d06b68ae55a3 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -102,6 +102,7 @@ obj-$(CONFIG_CRC4) += crc4.o obj-$(CONFIG_CRC7) += crc7.o obj-$(CONFIG_LIBCRC32C) += libcrc32c.o obj-$(CONFIG_CRC8) += crc8.o +obj-$(CONFIG_XXHASH) += xxhash.o obj-$(CONFIG_GENERIC_ALLOCATOR) += genalloc.o obj-$(CONFIG_842_COMPRESS) += 842/ diff --git a/lib/xxhash.c b/lib/xxhash.c new file mode 100644 index 000000000000..aa61e2a3802f --- /dev/null +++ b/lib/xxhash.c @@ -0,0 +1,500 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Copyright (C) 2012-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at: + * - xxHash homepage: http://cyan4973.github.io/xxHash/ + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +#include +#include +#include +#include +#include +#include +#include + +/*-************************************* + * Macros + **************************************/ +#define xxh_rotl32(x, r) ((x << r) | (x >> (32 - r))) +#define xxh_rotl64(x, r) ((x << r) | (x >> (64 - r))) + +#ifdef __LITTLE_ENDIAN +# define XXH_CPU_LITTLE_ENDIAN 1 +#else +# define XXH_CPU_LITTLE_ENDIAN 0 +#endif + +/*-************************************* + * Constants + **************************************/ +static const uint32_t PRIME32_1 = 2654435761U; +static const uint32_t PRIME32_2 = 2246822519U; +static const uint32_t PRIME32_3 = 3266489917U; +static const uint32_t PRIME32_4 = 668265263U; +static const uint32_t PRIME32_5 = 374761393U; + +static const uint64_t PRIME64_1 = 11400714785074694791ULL; +static const uint64_t PRIME64_2 = 14029467366897019727ULL; +static const uint64_t PRIME64_3 = 1609587929392839161ULL; +static const uint64_t PRIME64_4 = 9650029242287828579ULL; +static const uint64_t PRIME64_5 = 2870177450012600261ULL; + +/*-************************** + * Utils + ***************************/ +void xxh32_copy_state(struct xxh32_state *dst, const struct xxh32_state *src) +{ + memcpy(dst, src, sizeof(*dst)); +} +EXPORT_SYMBOL(xxh32_copy_state); + +void xxh64_copy_state(struct xxh64_state *dst, const struct xxh64_state *src) +{ + memcpy(dst, src, sizeof(*dst)); +} +EXPORT_SYMBOL(xxh64_copy_state); + +/*-*************************** + * Simple Hash Functions + ****************************/ +static uint32_t xxh32_round(uint32_t seed, const uint32_t input) +{ + seed += input * PRIME32_2; + seed = xxh_rotl32(seed, 13); + seed *= PRIME32_1; + return seed; +} + +uint32_t xxh32(const void *input, const size_t len, const uint32_t seed) +{ + const uint8_t *p = (const uint8_t *)input; + const uint8_t *b_end = p + len; + uint32_t h32; + + if (len >= 16) { + const uint8_t *const limit = b_end - 16; + uint32_t v1 = seed + PRIME32_1 + PRIME32_2; + uint32_t v2 = seed + PRIME32_2; + uint32_t v3 = seed + 0; + uint32_t v4 = seed - PRIME32_1; + + do { + v1 = xxh32_round(v1, get_unaligned_le32(p)); + p += 4; + v2 = xxh32_round(v2, get_unaligned_le32(p)); + p += 4; + v3 = xxh32_round(v3, get_unaligned_le32(p)); + p += 4; + v4 = xxh32_round(v4, get_unaligned_le32(p)); + p += 4; + } while (p <= limit); + + h32 = xxh_rotl32(v1, 1) + xxh_rotl32(v2, 7) + + xxh_rotl32(v3, 12) + xxh_rotl32(v4, 18); + } else { + h32 = seed + PRIME32_5; + } + + h32 += (uint32_t)len; + + while (p + 4 <= b_end) { + h32 += get_unaligned_le32(p) * PRIME32_3; + h32 = xxh_rotl32(h32, 17) * PRIME32_4; + p += 4; + } + + while (p < b_end) { + h32 += (*p) * PRIME32_5; + h32 = xxh_rotl32(h32, 11) * PRIME32_1; + p++; + } + + h32 ^= h32 >> 15; + h32 *= PRIME32_2; + h32 ^= h32 >> 13; + h32 *= PRIME32_3; + h32 ^= h32 >> 16; + + return h32; +} +EXPORT_SYMBOL(xxh32); + +static uint64_t xxh64_round(uint64_t acc, const uint64_t input) +{ + acc += input * PRIME64_2; + acc = xxh_rotl64(acc, 31); + acc *= PRIME64_1; + return acc; +} + +static uint64_t xxh64_merge_round(uint64_t acc, uint64_t val) +{ + val = xxh64_round(0, val); + acc ^= val; + acc = acc * PRIME64_1 + PRIME64_4; + return acc; +} + +uint64_t xxh64(const void *input, const size_t len, const uint64_t seed) +{ + const uint8_t *p = (const uint8_t *)input; + const uint8_t *const b_end = p + len; + uint64_t h64; + + if (len >= 32) { + const uint8_t *const limit = b_end - 32; + uint64_t v1 = seed + PRIME64_1 + PRIME64_2; + uint64_t v2 = seed + PRIME64_2; + uint64_t v3 = seed + 0; + uint64_t v4 = seed - PRIME64_1; + + do { + v1 = xxh64_round(v1, get_unaligned_le64(p)); + p += 8; + v2 = xxh64_round(v2, get_unaligned_le64(p)); + p += 8; + v3 = xxh64_round(v3, get_unaligned_le64(p)); + p += 8; + v4 = xxh64_round(v4, get_unaligned_le64(p)); + p += 8; + } while (p <= limit); + + h64 = xxh_rotl64(v1, 1) + xxh_rotl64(v2, 7) + + xxh_rotl64(v3, 12) + xxh_rotl64(v4, 18); + h64 = xxh64_merge_round(h64, v1); + h64 = xxh64_merge_round(h64, v2); + h64 = xxh64_merge_round(h64, v3); + h64 = xxh64_merge_round(h64, v4); + + } else { + h64 = seed + PRIME64_5; + } + + h64 += (uint64_t)len; + + while (p + 8 <= b_end) { + const uint64_t k1 = xxh64_round(0, get_unaligned_le64(p)); + + h64 ^= k1; + h64 = xxh_rotl64(h64, 27) * PRIME64_1 + PRIME64_4; + p += 8; + } + + if (p + 4 <= b_end) { + h64 ^= (uint64_t)(get_unaligned_le32(p)) * PRIME64_1; + h64 = xxh_rotl64(h64, 23) * PRIME64_2 + PRIME64_3; + p += 4; + } + + while (p < b_end) { + h64 ^= (*p) * PRIME64_5; + h64 = xxh_rotl64(h64, 11) * PRIME64_1; + p++; + } + + h64 ^= h64 >> 33; + h64 *= PRIME64_2; + h64 ^= h64 >> 29; + h64 *= PRIME64_3; + h64 ^= h64 >> 32; + + return h64; +} +EXPORT_SYMBOL(xxh64); + +/*-************************************************** + * Advanced Hash Functions + ***************************************************/ +void xxh32_reset(struct xxh32_state *statePtr, const uint32_t seed) +{ + /* use a local state for memcpy() to avoid strict-aliasing warnings */ + struct xxh32_state state; + + memset(&state, 0, sizeof(state)); + state.v1 = seed + PRIME32_1 + PRIME32_2; + state.v2 = seed + PRIME32_2; + state.v3 = seed + 0; + state.v4 = seed - PRIME32_1; + memcpy(statePtr, &state, sizeof(state)); +} +EXPORT_SYMBOL(xxh32_reset); + +void xxh64_reset(struct xxh64_state *statePtr, const uint64_t seed) +{ + /* use a local state for memcpy() to avoid strict-aliasing warnings */ + struct xxh64_state state; + + memset(&state, 0, sizeof(state)); + state.v1 = seed + PRIME64_1 + PRIME64_2; + state.v2 = seed + PRIME64_2; + state.v3 = seed + 0; + state.v4 = seed - PRIME64_1; + memcpy(statePtr, &state, sizeof(state)); +} +EXPORT_SYMBOL(xxh64_reset); + +int xxh32_update(struct xxh32_state *state, const void *input, const size_t len) +{ + const uint8_t *p = (const uint8_t *)input; + const uint8_t *const b_end = p + len; + + if (input == NULL) + return -EINVAL; + + state->total_len_32 += (uint32_t)len; + state->large_len |= (len >= 16) | (state->total_len_32 >= 16); + + if (state->memsize + len < 16) { /* fill in tmp buffer */ + memcpy((uint8_t *)(state->mem32) + state->memsize, input, len); + state->memsize += (uint32_t)len; + return 0; + } + + if (state->memsize) { /* some data left from previous update */ + const uint32_t *p32 = state->mem32; + + memcpy((uint8_t *)(state->mem32) + state->memsize, input, + 16 - state->memsize); + + state->v1 = xxh32_round(state->v1, get_unaligned_le32(p32)); + p32++; + state->v2 = xxh32_round(state->v2, get_unaligned_le32(p32)); + p32++; + state->v3 = xxh32_round(state->v3, get_unaligned_le32(p32)); + p32++; + state->v4 = xxh32_round(state->v4, get_unaligned_le32(p32)); + p32++; + + p += 16-state->memsize; + state->memsize = 0; + } + + if (p <= b_end - 16) { + const uint8_t *const limit = b_end - 16; + uint32_t v1 = state->v1; + uint32_t v2 = state->v2; + uint32_t v3 = state->v3; + uint32_t v4 = state->v4; + + do { + v1 = xxh32_round(v1, get_unaligned_le32(p)); + p += 4; + v2 = xxh32_round(v2, get_unaligned_le32(p)); + p += 4; + v3 = xxh32_round(v3, get_unaligned_le32(p)); + p += 4; + v4 = xxh32_round(v4, get_unaligned_le32(p)); + p += 4; + } while (p <= limit); + + state->v1 = v1; + state->v2 = v2; + state->v3 = v3; + state->v4 = v4; + } + + if (p < b_end) { + memcpy(state->mem32, p, (size_t)(b_end-p)); + state->memsize = (uint32_t)(b_end-p); + } + + return 0; +} +EXPORT_SYMBOL(xxh32_update); + +uint32_t xxh32_digest(const struct xxh32_state *state) +{ + const uint8_t *p = (const uint8_t *)state->mem32; + const uint8_t *const b_end = (const uint8_t *)(state->mem32) + + state->memsize; + uint32_t h32; + + if (state->large_len) { + h32 = xxh_rotl32(state->v1, 1) + xxh_rotl32(state->v2, 7) + + xxh_rotl32(state->v3, 12) + xxh_rotl32(state->v4, 18); + } else { + h32 = state->v3 /* == seed */ + PRIME32_5; + } + + h32 += state->total_len_32; + + while (p + 4 <= b_end) { + h32 += get_unaligned_le32(p) * PRIME32_3; + h32 = xxh_rotl32(h32, 17) * PRIME32_4; + p += 4; + } + + while (p < b_end) { + h32 += (*p) * PRIME32_5; + h32 = xxh_rotl32(h32, 11) * PRIME32_1; + p++; + } + + h32 ^= h32 >> 15; + h32 *= PRIME32_2; + h32 ^= h32 >> 13; + h32 *= PRIME32_3; + h32 ^= h32 >> 16; + + return h32; +} +EXPORT_SYMBOL(xxh32_digest); + +int xxh64_update(struct xxh64_state *state, const void *input, const size_t len) +{ + const uint8_t *p = (const uint8_t *)input; + const uint8_t *const b_end = p + len; + + if (input == NULL) + return -EINVAL; + + state->total_len += len; + + if (state->memsize + len < 32) { /* fill in tmp buffer */ + memcpy(((uint8_t *)state->mem64) + state->memsize, input, len); + state->memsize += (uint32_t)len; + return 0; + } + + if (state->memsize) { /* tmp buffer is full */ + uint64_t *p64 = state->mem64; + + memcpy(((uint8_t *)p64) + state->memsize, input, + 32 - state->memsize); + + state->v1 = xxh64_round(state->v1, get_unaligned_le64(p64)); + p64++; + state->v2 = xxh64_round(state->v2, get_unaligned_le64(p64)); + p64++; + state->v3 = xxh64_round(state->v3, get_unaligned_le64(p64)); + p64++; + state->v4 = xxh64_round(state->v4, get_unaligned_le64(p64)); + + p += 32 - state->memsize; + state->memsize = 0; + } + + if (p + 32 <= b_end) { + const uint8_t *const limit = b_end - 32; + uint64_t v1 = state->v1; + uint64_t v2 = state->v2; + uint64_t v3 = state->v3; + uint64_t v4 = state->v4; + + do { + v1 = xxh64_round(v1, get_unaligned_le64(p)); + p += 8; + v2 = xxh64_round(v2, get_unaligned_le64(p)); + p += 8; + v3 = xxh64_round(v3, get_unaligned_le64(p)); + p += 8; + v4 = xxh64_round(v4, get_unaligned_le64(p)); + p += 8; + } while (p <= limit); + + state->v1 = v1; + state->v2 = v2; + state->v3 = v3; + state->v4 = v4; + } + + if (p < b_end) { + memcpy(state->mem64, p, (size_t)(b_end-p)); + state->memsize = (uint32_t)(b_end - p); + } + + return 0; +} +EXPORT_SYMBOL(xxh64_update); + +uint64_t xxh64_digest(const struct xxh64_state *state) +{ + const uint8_t *p = (const uint8_t *)state->mem64; + const uint8_t *const b_end = (const uint8_t *)state->mem64 + + state->memsize; + uint64_t h64; + + if (state->total_len >= 32) { + const uint64_t v1 = state->v1; + const uint64_t v2 = state->v2; + const uint64_t v3 = state->v3; + const uint64_t v4 = state->v4; + + h64 = xxh_rotl64(v1, 1) + xxh_rotl64(v2, 7) + + xxh_rotl64(v3, 12) + xxh_rotl64(v4, 18); + h64 = xxh64_merge_round(h64, v1); + h64 = xxh64_merge_round(h64, v2); + h64 = xxh64_merge_round(h64, v3); + h64 = xxh64_merge_round(h64, v4); + } else { + h64 = state->v3 + PRIME64_5; + } + + h64 += (uint64_t)state->total_len; + + while (p + 8 <= b_end) { + const uint64_t k1 = xxh64_round(0, get_unaligned_le64(p)); + + h64 ^= k1; + h64 = xxh_rotl64(h64, 27) * PRIME64_1 + PRIME64_4; + p += 8; + } + + if (p + 4 <= b_end) { + h64 ^= (uint64_t)(get_unaligned_le32(p)) * PRIME64_1; + h64 = xxh_rotl64(h64, 23) * PRIME64_2 + PRIME64_3; + p += 4; + } + + while (p < b_end) { + h64 ^= (*p) * PRIME64_5; + h64 = xxh_rotl64(h64, 11) * PRIME64_1; + p++; + } + + h64 ^= h64 >> 33; + h64 *= PRIME64_2; + h64 ^= h64 >> 29; + h64 *= PRIME64_3; + h64 ^= h64 >> 32; + + return h64; +} +EXPORT_SYMBOL(xxh64_digest); + +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("xxHash"); -- cgit v1.2.3 From 73f3d1b48f5069d46ba48aa28c2898dc93185560 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Wed, 9 Aug 2017 19:35:53 -0700 Subject: lib: Add zstd modules Add zstd compression and decompression kernel modules. zstd offers a wide varity of compression speed and quality trade-offs. It can compress at speeds approaching lz4, and quality approaching lzma. zstd decompressions at speeds more than twice as fast as zlib, and decompression speed remains roughly the same across all compression levels. The code was ported from the upstream zstd source repository. The `linux/zstd.h` header was modified to match linux kernel style. The cross-platform and allocation code was stripped out. Instead zstd requires the caller to pass a preallocated workspace. The source files were clang-formatted [1] to match the Linux Kernel style as much as possible. Otherwise, the code was unmodified. We would like to avoid as much further manual modification to the source code as possible, so it will be easier to keep the kernel zstd up to date. I benchmarked zstd compression as a special character device. I ran zstd and zlib compression at several levels, as well as performing no compression, which measure the time spent copying the data to kernel space. Data is passed to the compresser 4096 B at a time. The benchmark file is located in the upstream zstd source repository under `contrib/linux-kernel/zstd_compress_test.c` [2]. I ran the benchmarks on a Ubuntu 14.04 VM with 2 cores and 4 GiB of RAM. The VM is running on a MacBook Pro with a 3.1 GHz Intel Core i7 processor, 16 GB of RAM, and a SSD. I benchmarked using `silesia.tar` [3], which is 211,988,480 B large. Run the following commands for the benchmark: sudo modprobe zstd_compress_test sudo mknod zstd_compress_test c 245 0 sudo cp silesia.tar zstd_compress_test The time is reported by the time of the userland `cp`. The MB/s is computed with 1,536,217,008 B / time(buffer size, hash) which includes the time to copy from userland. The Adjusted MB/s is computed with 1,536,217,088 B / (time(buffer size, hash) - time(buffer size, none)). The memory reported is the amount of memory the compressor requests. | Method | Size (B) | Time (s) | Ratio | MB/s | Adj MB/s | Mem (MB) | |----------|----------|----------|-------|---------|----------|----------| | none | 11988480 | 0.100 | 1 | 2119.88 | - | - | | zstd -1 | 73645762 | 1.044 | 2.878 | 203.05 | 224.56 | 1.23 | | zstd -3 | 66988878 | 1.761 | 3.165 | 120.38 | 127.63 | 2.47 | | zstd -5 | 65001259 | 2.563 | 3.261 | 82.71 | 86.07 | 2.86 | | zstd -10 | 60165346 | 13.242 | 3.523 | 16.01 | 16.13 | 13.22 | | zstd -15 | 58009756 | 47.601 | 3.654 | 4.45 | 4.46 | 21.61 | | zstd -19 | 54014593 | 102.835 | 3.925 | 2.06 | 2.06 | 60.15 | | zlib -1 | 77260026 | 2.895 | 2.744 | 73.23 | 75.85 | 0.27 | | zlib -3 | 72972206 | 4.116 | 2.905 | 51.50 | 52.79 | 0.27 | | zlib -6 | 68190360 | 9.633 | 3.109 | 22.01 | 22.24 | 0.27 | | zlib -9 | 67613382 | 22.554 | 3.135 | 9.40 | 9.44 | 0.27 | I benchmarked zstd decompression using the same method on the same machine. The benchmark file is located in the upstream zstd repo under `contrib/linux-kernel/zstd_decompress_test.c` [4]. The memory reported is the amount of memory required to decompress data compressed with the given compression level. If you know the maximum size of your input, you can reduce the memory usage of decompression irrespective of the compression level. | Method | Time (s) | MB/s | Adjusted MB/s | Memory (MB) | |----------|----------|---------|---------------|-------------| | none | 0.025 | 8479.54 | - | - | | zstd -1 | 0.358 | 592.15 | 636.60 | 0.84 | | zstd -3 | 0.396 | 535.32 | 571.40 | 1.46 | | zstd -5 | 0.396 | 535.32 | 571.40 | 1.46 | | zstd -10 | 0.374 | 566.81 | 607.42 | 2.51 | | zstd -15 | 0.379 | 559.34 | 598.84 | 4.61 | | zstd -19 | 0.412 | 514.54 | 547.77 | 8.80 | | zlib -1 | 0.940 | 225.52 | 231.68 | 0.04 | | zlib -3 | 0.883 | 240.08 | 247.07 | 0.04 | | zlib -6 | 0.844 | 251.17 | 258.84 | 0.04 | | zlib -9 | 0.837 | 253.27 | 287.64 | 0.04 | Tested in userland using the test-suite in the zstd repo under `contrib/linux-kernel/test/UserlandTest.cpp` [5] by mocking the kernel functions. Fuzz tested using libfuzzer [6] with the fuzz harnesses under `contrib/linux-kernel/test/{RoundTripCrash.c,DecompressCrash.c}` [7] [8] with ASAN, UBSAN, and MSAN. Additionaly, it was tested while testing the BtrFS and SquashFS patches coming next. [1] https://clang.llvm.org/docs/ClangFormat.html [2] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/zstd_compress_test.c [3] http://sun.aei.polsl.pl/~sdeor/index.php?page=silesia [4] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/zstd_decompress_test.c [5] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/test/UserlandTest.cpp [6] http://llvm.org/docs/LibFuzzer.html [7] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/test/RoundTripCrash.c [8] https://github.com/facebook/zstd/blob/dev/contrib/linux-kernel/test/DecompressCrash.c zstd source repository: https://github.com/facebook/zstd Signed-off-by: Nick Terrell Signed-off-by: Chris Mason --- include/linux/zstd.h | 1157 +++++++++++++++ lib/Kconfig | 8 + lib/Makefile | 2 + lib/zstd/Makefile | 18 + lib/zstd/bitstream.h | 374 +++++ lib/zstd/compress.c | 3484 +++++++++++++++++++++++++++++++++++++++++++++ lib/zstd/decompress.c | 2528 ++++++++++++++++++++++++++++++++ lib/zstd/entropy_common.c | 243 ++++ lib/zstd/error_private.h | 53 + lib/zstd/fse.h | 575 ++++++++ lib/zstd/fse_compress.c | 795 +++++++++++ lib/zstd/fse_decompress.c | 332 +++++ lib/zstd/huf.h | 212 +++ lib/zstd/huf_compress.c | 770 ++++++++++ lib/zstd/huf_decompress.c | 960 +++++++++++++ lib/zstd/mem.h | 151 ++ lib/zstd/zstd_common.c | 75 + lib/zstd/zstd_internal.h | 263 ++++ lib/zstd/zstd_opt.h | 1014 +++++++++++++ 19 files changed, 13014 insertions(+) create mode 100644 include/linux/zstd.h create mode 100644 lib/zstd/Makefile create mode 100644 lib/zstd/bitstream.h create mode 100644 lib/zstd/compress.c create mode 100644 lib/zstd/decompress.c create mode 100644 lib/zstd/entropy_common.c create mode 100644 lib/zstd/error_private.h create mode 100644 lib/zstd/fse.h create mode 100644 lib/zstd/fse_compress.c create mode 100644 lib/zstd/fse_decompress.c create mode 100644 lib/zstd/huf.h create mode 100644 lib/zstd/huf_compress.c create mode 100644 lib/zstd/huf_decompress.c create mode 100644 lib/zstd/mem.h create mode 100644 lib/zstd/zstd_common.c create mode 100644 lib/zstd/zstd_internal.h create mode 100644 lib/zstd/zstd_opt.h (limited to 'lib') diff --git a/include/linux/zstd.h b/include/linux/zstd.h new file mode 100644 index 000000000000..249575e2485f --- /dev/null +++ b/include/linux/zstd.h @@ -0,0 +1,1157 @@ +/* + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +#ifndef ZSTD_H +#define ZSTD_H + +/* ====== Dependency ======*/ +#include /* size_t */ + + +/*-***************************************************************************** + * Introduction + * + * zstd, short for Zstandard, is a fast lossless compression algorithm, + * targeting real-time compression scenarios at zlib-level and better + * compression ratios. The zstd compression library provides in-memory + * compression and decompression functions. The library supports compression + * levels from 1 up to ZSTD_maxCLevel() which is 22. Levels >= 20, labeled + * ultra, should be used with caution, as they require more memory. + * Compression can be done in: + * - a single step, reusing a context (described as Explicit memory management) + * - unbounded multiple steps (described as Streaming compression) + * The compression ratio achievable on small data can be highly improved using + * compression with a dictionary in: + * - a single step (described as Simple dictionary API) + * - a single step, reusing a dictionary (described as Fast dictionary API) + ******************************************************************************/ + +/*====== Helper functions ======*/ + +/** + * enum ZSTD_ErrorCode - zstd error codes + * + * Functions that return size_t can be checked for errors using ZSTD_isError() + * and the ZSTD_ErrorCode can be extracted using ZSTD_getErrorCode(). + */ +typedef enum { + ZSTD_error_no_error, + ZSTD_error_GENERIC, + ZSTD_error_prefix_unknown, + ZSTD_error_version_unsupported, + ZSTD_error_parameter_unknown, + ZSTD_error_frameParameter_unsupported, + ZSTD_error_frameParameter_unsupportedBy32bits, + ZSTD_error_frameParameter_windowTooLarge, + ZSTD_error_compressionParameter_unsupported, + ZSTD_error_init_missing, + ZSTD_error_memory_allocation, + ZSTD_error_stage_wrong, + ZSTD_error_dstSize_tooSmall, + ZSTD_error_srcSize_wrong, + ZSTD_error_corruption_detected, + ZSTD_error_checksum_wrong, + ZSTD_error_tableLog_tooLarge, + ZSTD_error_maxSymbolValue_tooLarge, + ZSTD_error_maxSymbolValue_tooSmall, + ZSTD_error_dictionary_corrupted, + ZSTD_error_dictionary_wrong, + ZSTD_error_dictionaryCreation_failed, + ZSTD_error_maxCode +} ZSTD_ErrorCode; + +/** + * ZSTD_maxCLevel() - maximum compression level available + * + * Return: Maximum compression level available. + */ +int ZSTD_maxCLevel(void); +/** + * ZSTD_compressBound() - maximum compressed size in worst case scenario + * @srcSize: The size of the data to compress. + * + * Return: The maximum compressed size in the worst case scenario. + */ +size_t ZSTD_compressBound(size_t srcSize); +/** + * ZSTD_isError() - tells if a size_t function result is an error code + * @code: The function result to check for error. + * + * Return: Non-zero iff the code is an error. + */ +static __attribute__((unused)) unsigned int ZSTD_isError(size_t code) +{ + return code > (size_t)-ZSTD_error_maxCode; +} +/** + * ZSTD_getErrorCode() - translates an error function result to a ZSTD_ErrorCode + * @functionResult: The result of a function for which ZSTD_isError() is true. + * + * Return: The ZSTD_ErrorCode corresponding to the functionResult or 0 + * if the functionResult isn't an error. + */ +static __attribute__((unused)) ZSTD_ErrorCode ZSTD_getErrorCode( + size_t functionResult) +{ + if (!ZSTD_isError(functionResult)) + return (ZSTD_ErrorCode)0; + return (ZSTD_ErrorCode)(0 - functionResult); +} + +/** + * enum ZSTD_strategy - zstd compression search strategy + * + * From faster to stronger. + */ +typedef enum { + ZSTD_fast, + ZSTD_dfast, + ZSTD_greedy, + ZSTD_lazy, + ZSTD_lazy2, + ZSTD_btlazy2, + ZSTD_btopt, + ZSTD_btopt2 +} ZSTD_strategy; + +/** + * struct ZSTD_compressionParameters - zstd compression parameters + * @windowLog: Log of the largest match distance. Larger means more + * compression, and more memory needed during decompression. + * @chainLog: Fully searched segment. Larger means more compression, slower, + * and more memory (useless for fast). + * @hashLog: Dispatch table. Larger means more compression, + * slower, and more memory. + * @searchLog: Number of searches. Larger means more compression and slower. + * @searchLength: Match length searched. Larger means faster decompression, + * sometimes less compression. + * @targetLength: Acceptable match size for optimal parser (only). Larger means + * more compression, and slower. + * @strategy: The zstd compression strategy. + */ +typedef struct { + unsigned int windowLog; + unsigned int chainLog; + unsigned int hashLog; + unsigned int searchLog; + unsigned int searchLength; + unsigned int targetLength; + ZSTD_strategy strategy; +} ZSTD_compressionParameters; + +/** + * struct ZSTD_frameParameters - zstd frame parameters + * @contentSizeFlag: Controls whether content size will be present in the frame + * header (when known). + * @checksumFlag: Controls whether a 32-bit checksum is generated at the end + * of the frame for error detection. + * @noDictIDFlag: Controls whether dictID will be saved into the frame header + * when using dictionary compression. + * + * The default value is all fields set to 0. + */ +typedef struct { + unsigned int contentSizeFlag; + unsigned int checksumFlag; + unsigned int noDictIDFlag; +} ZSTD_frameParameters; + +/** + * struct ZSTD_parameters - zstd parameters + * @cParams: The compression parameters. + * @fParams: The frame parameters. + */ +typedef struct { + ZSTD_compressionParameters cParams; + ZSTD_frameParameters fParams; +} ZSTD_parameters; + +/** + * ZSTD_getCParams() - returns ZSTD_compressionParameters for selected level + * @compressionLevel: The compression level from 1 to ZSTD_maxCLevel(). + * @estimatedSrcSize: The estimated source size to compress or 0 if unknown. + * @dictSize: The dictionary size or 0 if a dictionary isn't being used. + * + * Return: The selected ZSTD_compressionParameters. + */ +ZSTD_compressionParameters ZSTD_getCParams(int compressionLevel, + unsigned long long estimatedSrcSize, size_t dictSize); + +/** + * ZSTD_getParams() - returns ZSTD_parameters for selected level + * @compressionLevel: The compression level from 1 to ZSTD_maxCLevel(). + * @estimatedSrcSize: The estimated source size to compress or 0 if unknown. + * @dictSize: The dictionary size or 0 if a dictionary isn't being used. + * + * The same as ZSTD_getCParams() except also selects the default frame + * parameters (all zero). + * + * Return: The selected ZSTD_parameters. + */ +ZSTD_parameters ZSTD_getParams(int compressionLevel, + unsigned long long estimatedSrcSize, size_t dictSize); + +/*-************************************* + * Explicit memory management + **************************************/ + +/** + * ZSTD_CCtxWorkspaceBound() - amount of memory needed to initialize a ZSTD_CCtx + * @cParams: The compression parameters to be used for compression. + * + * If multiple compression parameters might be used, the caller must call + * ZSTD_CCtxWorkspaceBound() for each set of parameters and use the maximum + * size. + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initCCtx(). + */ +size_t ZSTD_CCtxWorkspaceBound(ZSTD_compressionParameters cParams); + +/** + * struct ZSTD_CCtx - the zstd compression context + * + * When compressing many times it is recommended to allocate a context just once + * and reuse it for each successive compression operation. + */ +typedef struct ZSTD_CCtx_s ZSTD_CCtx; +/** + * ZSTD_initCCtx() - initialize a zstd compression context + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. Use ZSTD_CCtxWorkspaceBound() to + * determine how large the workspace must be. + * + * Return: A compression context emplaced into workspace. + */ +ZSTD_CCtx *ZSTD_initCCtx(void *workspace, size_t workspaceSize); + +/** + * ZSTD_compressCCtx() - compress src into dst + * @ctx: The context. Must have been initialized with a workspace at + * least as large as ZSTD_CCtxWorkspaceBound(params.cParams). + * @dst: The buffer to compress src into. + * @dstCapacity: The size of the destination buffer. May be any size, but + * ZSTD_compressBound(srcSize) is guaranteed to be large enough. + * @src: The data to compress. + * @srcSize: The size of the data to compress. + * @params: The parameters to use for compression. See ZSTD_getParams(). + * + * Return: The compressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_compressCCtx(ZSTD_CCtx *ctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize, ZSTD_parameters params); + +/** + * ZSTD_DCtxWorkspaceBound() - amount of memory needed to initialize a ZSTD_DCtx + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initDCtx(). + */ +size_t ZSTD_DCtxWorkspaceBound(void); + +/** + * struct ZSTD_DCtx - the zstd decompression context + * + * When decompressing many times it is recommended to allocate a context just + * once and reuse it for each successive decompression operation. + */ +typedef struct ZSTD_DCtx_s ZSTD_DCtx; +/** + * ZSTD_initDCtx() - initialize a zstd decompression context + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. Use ZSTD_DCtxWorkspaceBound() to + * determine how large the workspace must be. + * + * Return: A decompression context emplaced into workspace. + */ +ZSTD_DCtx *ZSTD_initDCtx(void *workspace, size_t workspaceSize); + +/** + * ZSTD_decompressDCtx() - decompress zstd compressed src into dst + * @ctx: The decompression context. + * @dst: The buffer to decompress src into. + * @dstCapacity: The size of the destination buffer. Must be at least as large + * as the decompressed size. If the caller cannot upper bound the + * decompressed size, then it's better to use the streaming API. + * @src: The zstd compressed data to decompress. Multiple concatenated + * frames and skippable frames are allowed. + * @srcSize: The exact size of the data to decompress. + * + * Return: The decompressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_decompressDCtx(ZSTD_DCtx *ctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); + +/*-************************ + * Simple dictionary API + **************************/ + +/** + * ZSTD_compress_usingDict() - compress src into dst using a dictionary + * @ctx: The context. Must have been initialized with a workspace at + * least as large as ZSTD_CCtxWorkspaceBound(params.cParams). + * @dst: The buffer to compress src into. + * @dstCapacity: The size of the destination buffer. May be any size, but + * ZSTD_compressBound(srcSize) is guaranteed to be large enough. + * @src: The data to compress. + * @srcSize: The size of the data to compress. + * @dict: The dictionary to use for compression. + * @dictSize: The size of the dictionary. + * @params: The parameters to use for compression. See ZSTD_getParams(). + * + * Compression using a predefined dictionary. The same dictionary must be used + * during decompression. + * + * Return: The compressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_compress_usingDict(ZSTD_CCtx *ctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize, const void *dict, size_t dictSize, + ZSTD_parameters params); + +/** + * ZSTD_decompress_usingDict() - decompress src into dst using a dictionary + * @ctx: The decompression context. + * @dst: The buffer to decompress src into. + * @dstCapacity: The size of the destination buffer. Must be at least as large + * as the decompressed size. If the caller cannot upper bound the + * decompressed size, then it's better to use the streaming API. + * @src: The zstd compressed data to decompress. Multiple concatenated + * frames and skippable frames are allowed. + * @srcSize: The exact size of the data to decompress. + * @dict: The dictionary to use for decompression. The same dictionary + * must've been used to compress the data. + * @dictSize: The size of the dictionary. + * + * Return: The decompressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_decompress_usingDict(ZSTD_DCtx *ctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize, const void *dict, size_t dictSize); + +/*-************************** + * Fast dictionary API + ***************************/ + +/** + * ZSTD_CDictWorkspaceBound() - memory needed to initialize a ZSTD_CDict + * @cParams: The compression parameters to be used for compression. + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initCDict(). + */ +size_t ZSTD_CDictWorkspaceBound(ZSTD_compressionParameters cParams); + +/** + * struct ZSTD_CDict - a digested dictionary to be used for compression + */ +typedef struct ZSTD_CDict_s ZSTD_CDict; + +/** + * ZSTD_initCDict() - initialize a digested dictionary for compression + * @dictBuffer: The dictionary to digest. The buffer is referenced by the + * ZSTD_CDict so it must outlive the returned ZSTD_CDict. + * @dictSize: The size of the dictionary. + * @params: The parameters to use for compression. See ZSTD_getParams(). + * @workspace: The workspace. It must outlive the returned ZSTD_CDict. + * @workspaceSize: The workspace size. Must be at least + * ZSTD_CDictWorkspaceBound(params.cParams). + * + * When compressing multiple messages / blocks with the same dictionary it is + * recommended to load it just once. The ZSTD_CDict merely references the + * dictBuffer, so it must outlive the returned ZSTD_CDict. + * + * Return: The digested dictionary emplaced into workspace. + */ +ZSTD_CDict *ZSTD_initCDict(const void *dictBuffer, size_t dictSize, + ZSTD_parameters params, void *workspace, size_t workspaceSize); + +/** + * ZSTD_compress_usingCDict() - compress src into dst using a ZSTD_CDict + * @ctx: The context. Must have been initialized with a workspace at + * least as large as ZSTD_CCtxWorkspaceBound(cParams) where + * cParams are the compression parameters used to initialize the + * cdict. + * @dst: The buffer to compress src into. + * @dstCapacity: The size of the destination buffer. May be any size, but + * ZSTD_compressBound(srcSize) is guaranteed to be large enough. + * @src: The data to compress. + * @srcSize: The size of the data to compress. + * @cdict: The digested dictionary to use for compression. + * @params: The parameters to use for compression. See ZSTD_getParams(). + * + * Compression using a digested dictionary. The same dictionary must be used + * during decompression. + * + * Return: The compressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_compress_usingCDict(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize, const ZSTD_CDict *cdict); + + +/** + * ZSTD_DDictWorkspaceBound() - memory needed to initialize a ZSTD_DDict + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initDDict(). + */ +size_t ZSTD_DDictWorkspaceBound(void); + +/** + * struct ZSTD_DDict - a digested dictionary to be used for decompression + */ +typedef struct ZSTD_DDict_s ZSTD_DDict; + +/** + * ZSTD_initDDict() - initialize a digested dictionary for decompression + * @dictBuffer: The dictionary to digest. The buffer is referenced by the + * ZSTD_DDict so it must outlive the returned ZSTD_DDict. + * @dictSize: The size of the dictionary. + * @workspace: The workspace. It must outlive the returned ZSTD_DDict. + * @workspaceSize: The workspace size. Must be at least + * ZSTD_DDictWorkspaceBound(). + * + * When decompressing multiple messages / blocks with the same dictionary it is + * recommended to load it just once. The ZSTD_DDict merely references the + * dictBuffer, so it must outlive the returned ZSTD_DDict. + * + * Return: The digested dictionary emplaced into workspace. + */ +ZSTD_DDict *ZSTD_initDDict(const void *dictBuffer, size_t dictSize, + void *workspace, size_t workspaceSize); + +/** + * ZSTD_decompress_usingDDict() - decompress src into dst using a ZSTD_DDict + * @ctx: The decompression context. + * @dst: The buffer to decompress src into. + * @dstCapacity: The size of the destination buffer. Must be at least as large + * as the decompressed size. If the caller cannot upper bound the + * decompressed size, then it's better to use the streaming API. + * @src: The zstd compressed data to decompress. Multiple concatenated + * frames and skippable frames are allowed. + * @srcSize: The exact size of the data to decompress. + * @ddict: The digested dictionary to use for decompression. The same + * dictionary must've been used to compress the data. + * + * Return: The decompressed size or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_decompress_usingDDict(ZSTD_DCtx *dctx, void *dst, + size_t dstCapacity, const void *src, size_t srcSize, + const ZSTD_DDict *ddict); + + +/*-************************** + * Streaming + ***************************/ + +/** + * struct ZSTD_inBuffer - input buffer for streaming + * @src: Start of the input buffer. + * @size: Size of the input buffer. + * @pos: Position where reading stopped. Will be updated. + * Necessarily 0 <= pos <= size. + */ +typedef struct ZSTD_inBuffer_s { + const void *src; + size_t size; + size_t pos; +} ZSTD_inBuffer; + +/** + * struct ZSTD_outBuffer - output buffer for streaming + * @dst: Start of the output buffer. + * @size: Size of the output buffer. + * @pos: Position where writing stopped. Will be updated. + * Necessarily 0 <= pos <= size. + */ +typedef struct ZSTD_outBuffer_s { + void *dst; + size_t size; + size_t pos; +} ZSTD_outBuffer; + + + +/*-***************************************************************************** + * Streaming compression - HowTo + * + * A ZSTD_CStream object is required to track streaming operation. + * Use ZSTD_initCStream() to initialize a ZSTD_CStream object. + * ZSTD_CStream objects can be reused multiple times on consecutive compression + * operations. It is recommended to re-use ZSTD_CStream in situations where many + * streaming operations will be achieved consecutively. Use one separate + * ZSTD_CStream per thread for parallel execution. + * + * Use ZSTD_compressStream() repetitively to consume input stream. + * The function will automatically update both `pos` fields. + * Note that it may not consume the entire input, in which case `pos < size`, + * and it's up to the caller to present again remaining data. + * It returns a hint for the preferred number of bytes to use as an input for + * the next function call. + * + * At any moment, it's possible to flush whatever data remains within internal + * buffer, using ZSTD_flushStream(). `output->pos` will be updated. There might + * still be some content left within the internal buffer if `output->size` is + * too small. It returns the number of bytes left in the internal buffer and + * must be called until it returns 0. + * + * ZSTD_endStream() instructs to finish a frame. It will perform a flush and + * write frame epilogue. The epilogue is required for decoders to consider a + * frame completed. Similar to ZSTD_flushStream(), it may not be able to flush + * the full content if `output->size` is too small. In which case, call again + * ZSTD_endStream() to complete the flush. It returns the number of bytes left + * in the internal buffer and must be called until it returns 0. + ******************************************************************************/ + +/** + * ZSTD_CStreamWorkspaceBound() - memory needed to initialize a ZSTD_CStream + * @cParams: The compression parameters to be used for compression. + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initCStream() and ZSTD_initCStream_usingCDict(). + */ +size_t ZSTD_CStreamWorkspaceBound(ZSTD_compressionParameters cParams); + +/** + * struct ZSTD_CStream - the zstd streaming compression context + */ +typedef struct ZSTD_CStream_s ZSTD_CStream; + +/*===== ZSTD_CStream management functions =====*/ +/** + * ZSTD_initCStream() - initialize a zstd streaming compression context + * @params: The zstd compression parameters. + * @pledgedSrcSize: If params.fParams.contentSizeFlag == 1 then the caller must + * pass the source size (zero means empty source). Otherwise, + * the caller may optionally pass the source size, or zero if + * unknown. + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. + * Use ZSTD_CStreamWorkspaceBound(params.cParams) to determine + * how large the workspace must be. + * + * Return: The zstd streaming compression context. + */ +ZSTD_CStream *ZSTD_initCStream(ZSTD_parameters params, + unsigned long long pledgedSrcSize, void *workspace, + size_t workspaceSize); + +/** + * ZSTD_initCStream_usingCDict() - initialize a streaming compression context + * @cdict: The digested dictionary to use for compression. + * @pledgedSrcSize: Optionally the source size, or zero if unknown. + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. Call ZSTD_CStreamWorkspaceBound() + * with the cParams used to initialize the cdict to determine + * how large the workspace must be. + * + * Return: The zstd streaming compression context. + */ +ZSTD_CStream *ZSTD_initCStream_usingCDict(const ZSTD_CDict *cdict, + unsigned long long pledgedSrcSize, void *workspace, + size_t workspaceSize); + +/*===== Streaming compression functions =====*/ +/** + * ZSTD_resetCStream() - reset the context using parameters from creation + * @zcs: The zstd streaming compression context to reset. + * @pledgedSrcSize: Optionally the source size, or zero if unknown. + * + * Resets the context using the parameters from creation. Skips dictionary + * loading, since it can be reused. If `pledgedSrcSize` is non-zero the frame + * content size is always written into the frame header. + * + * Return: Zero or an error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_resetCStream(ZSTD_CStream *zcs, unsigned long long pledgedSrcSize); +/** + * ZSTD_compressStream() - streaming compress some of input into output + * @zcs: The zstd streaming compression context. + * @output: Destination buffer. `output->pos` is updated to indicate how much + * compressed data was written. + * @input: Source buffer. `input->pos` is updated to indicate how much data was + * read. Note that it may not consume the entire input, in which case + * `input->pos < input->size`, and it's up to the caller to present + * remaining data again. + * + * The `input` and `output` buffers may be any size. Guaranteed to make some + * forward progress if `input` and `output` are not empty. + * + * Return: A hint for the number of bytes to use as the input for the next + * function call or an error, which can be checked using + * ZSTD_isError(). + */ +size_t ZSTD_compressStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output, + ZSTD_inBuffer *input); +/** + * ZSTD_flushStream() - flush internal buffers into output + * @zcs: The zstd streaming compression context. + * @output: Destination buffer. `output->pos` is updated to indicate how much + * compressed data was written. + * + * ZSTD_flushStream() must be called until it returns 0, meaning all the data + * has been flushed. Since ZSTD_flushStream() causes a block to be ended, + * calling it too often will degrade the compression ratio. + * + * Return: The number of bytes still present within internal buffers or an + * error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_flushStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output); +/** + * ZSTD_endStream() - flush internal buffers into output and end the frame + * @zcs: The zstd streaming compression context. + * @output: Destination buffer. `output->pos` is updated to indicate how much + * compressed data was written. + * + * ZSTD_endStream() must be called until it returns 0, meaning all the data has + * been flushed and the frame epilogue has been written. + * + * Return: The number of bytes still present within internal buffers or an + * error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_endStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output); + +/** + * ZSTD_CStreamInSize() - recommended size for the input buffer + * + * Return: The recommended size for the input buffer. + */ +size_t ZSTD_CStreamInSize(void); +/** + * ZSTD_CStreamOutSize() - recommended size for the output buffer + * + * When the output buffer is at least this large, it is guaranteed to be large + * enough to flush at least one complete compressed block. + * + * Return: The recommended size for the output buffer. + */ +size_t ZSTD_CStreamOutSize(void); + + + +/*-***************************************************************************** + * Streaming decompression - HowTo + * + * A ZSTD_DStream object is required to track streaming operations. + * Use ZSTD_initDStream() to initialize a ZSTD_DStream object. + * ZSTD_DStream objects can be re-used multiple times. + * + * Use ZSTD_decompressStream() repetitively to consume your input. + * The function will update both `pos` fields. + * If `input->pos < input->size`, some input has not been consumed. + * It's up to the caller to present again remaining data. + * If `output->pos < output->size`, decoder has flushed everything it could. + * Returns 0 iff a frame is completely decoded and fully flushed. + * Otherwise it returns a suggested next input size that will never load more + * than the current frame. + ******************************************************************************/ + +/** + * ZSTD_DStreamWorkspaceBound() - memory needed to initialize a ZSTD_DStream + * @maxWindowSize: The maximum window size allowed for compressed frames. + * + * Return: A lower bound on the size of the workspace that is passed to + * ZSTD_initDStream() and ZSTD_initDStream_usingDDict(). + */ +size_t ZSTD_DStreamWorkspaceBound(size_t maxWindowSize); + +/** + * struct ZSTD_DStream - the zstd streaming decompression context + */ +typedef struct ZSTD_DStream_s ZSTD_DStream; +/*===== ZSTD_DStream management functions =====*/ +/** + * ZSTD_initDStream() - initialize a zstd streaming decompression context + * @maxWindowSize: The maximum window size allowed for compressed frames. + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. + * Use ZSTD_DStreamWorkspaceBound(maxWindowSize) to determine + * how large the workspace must be. + * + * Return: The zstd streaming decompression context. + */ +ZSTD_DStream *ZSTD_initDStream(size_t maxWindowSize, void *workspace, + size_t workspaceSize); +/** + * ZSTD_initDStream_usingDDict() - initialize streaming decompression context + * @maxWindowSize: The maximum window size allowed for compressed frames. + * @ddict: The digested dictionary to use for decompression. + * @workspace: The workspace to emplace the context into. It must outlive + * the returned context. + * @workspaceSize: The size of workspace. + * Use ZSTD_DStreamWorkspaceBound(maxWindowSize) to determine + * how large the workspace must be. + * + * Return: The zstd streaming decompression context. + */ +ZSTD_DStream *ZSTD_initDStream_usingDDict(size_t maxWindowSize, + const ZSTD_DDict *ddict, void *workspace, size_t workspaceSize); + +/*===== Streaming decompression functions =====*/ +/** + * ZSTD_resetDStream() - reset the context using parameters from creation + * @zds: The zstd streaming decompression context to reset. + * + * Resets the context using the parameters from creation. Skips dictionary + * loading, since it can be reused. + * + * Return: Zero or an error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_resetDStream(ZSTD_DStream *zds); +/** + * ZSTD_decompressStream() - streaming decompress some of input into output + * @zds: The zstd streaming decompression context. + * @output: Destination buffer. `output.pos` is updated to indicate how much + * decompressed data was written. + * @input: Source buffer. `input.pos` is updated to indicate how much data was + * read. Note that it may not consume the entire input, in which case + * `input.pos < input.size`, and it's up to the caller to present + * remaining data again. + * + * The `input` and `output` buffers may be any size. Guaranteed to make some + * forward progress if `input` and `output` are not empty. + * ZSTD_decompressStream() will not consume the last byte of the frame until + * the entire frame is flushed. + * + * Return: Returns 0 iff a frame is completely decoded and fully flushed. + * Otherwise returns a hint for the number of bytes to use as the input + * for the next function call or an error, which can be checked using + * ZSTD_isError(). The size hint will never load more than the frame. + */ +size_t ZSTD_decompressStream(ZSTD_DStream *zds, ZSTD_outBuffer *output, + ZSTD_inBuffer *input); + +/** + * ZSTD_DStreamInSize() - recommended size for the input buffer + * + * Return: The recommended size for the input buffer. + */ +size_t ZSTD_DStreamInSize(void); +/** + * ZSTD_DStreamOutSize() - recommended size for the output buffer + * + * When the output buffer is at least this large, it is guaranteed to be large + * enough to flush at least one complete decompressed block. + * + * Return: The recommended size for the output buffer. + */ +size_t ZSTD_DStreamOutSize(void); + + +/* --- Constants ---*/ +#define ZSTD_MAGICNUMBER 0xFD2FB528 /* >= v0.8.0 */ +#define ZSTD_MAGIC_SKIPPABLE_START 0x184D2A50U + +#define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1) +#define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) + +#define ZSTD_WINDOWLOG_MAX_32 27 +#define ZSTD_WINDOWLOG_MAX_64 27 +#define ZSTD_WINDOWLOG_MAX \ + ((unsigned int)(sizeof(size_t) == 4 \ + ? ZSTD_WINDOWLOG_MAX_32 \ + : ZSTD_WINDOWLOG_MAX_64)) +#define ZSTD_WINDOWLOG_MIN 10 +#define ZSTD_HASHLOG_MAX ZSTD_WINDOWLOG_MAX +#define ZSTD_HASHLOG_MIN 6 +#define ZSTD_CHAINLOG_MAX (ZSTD_WINDOWLOG_MAX+1) +#define ZSTD_CHAINLOG_MIN ZSTD_HASHLOG_MIN +#define ZSTD_HASHLOG3_MAX 17 +#define ZSTD_SEARCHLOG_MAX (ZSTD_WINDOWLOG_MAX-1) +#define ZSTD_SEARCHLOG_MIN 1 +/* only for ZSTD_fast, other strategies are limited to 6 */ +#define ZSTD_SEARCHLENGTH_MAX 7 +/* only for ZSTD_btopt, other strategies are limited to 4 */ +#define ZSTD_SEARCHLENGTH_MIN 3 +#define ZSTD_TARGETLENGTH_MIN 4 +#define ZSTD_TARGETLENGTH_MAX 999 + +/* for static allocation */ +#define ZSTD_FRAMEHEADERSIZE_MAX 18 +#define ZSTD_FRAMEHEADERSIZE_MIN 6 +static const size_t ZSTD_frameHeaderSize_prefix = 5; +static const size_t ZSTD_frameHeaderSize_min = ZSTD_FRAMEHEADERSIZE_MIN; +static const size_t ZSTD_frameHeaderSize_max = ZSTD_FRAMEHEADERSIZE_MAX; +/* magic number + skippable frame length */ +static const size_t ZSTD_skippableHeaderSize = 8; + + +/*-************************************* + * Compressed size functions + **************************************/ + +/** + * ZSTD_findFrameCompressedSize() - returns the size of a compressed frame + * @src: Source buffer. It should point to the start of a zstd encoded frame + * or a skippable frame. + * @srcSize: The size of the source buffer. It must be at least as large as the + * size of the frame. + * + * Return: The compressed size of the frame pointed to by `src` or an error, + * which can be check with ZSTD_isError(). + * Suitable to pass to ZSTD_decompress() or similar functions. + */ +size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize); + +/*-************************************* + * Decompressed size functions + **************************************/ +/** + * ZSTD_getFrameContentSize() - returns the content size in a zstd frame header + * @src: It should point to the start of a zstd encoded frame. + * @srcSize: The size of the source buffer. It must be at least as large as the + * frame header. `ZSTD_frameHeaderSize_max` is always large enough. + * + * Return: The frame content size stored in the frame header if known. + * `ZSTD_CONTENTSIZE_UNKNOWN` if the content size isn't stored in the + * frame header. `ZSTD_CONTENTSIZE_ERROR` on invalid input. + */ +unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize); + +/** + * ZSTD_findDecompressedSize() - returns decompressed size of a series of frames + * @src: It should point to the start of a series of zstd encoded and/or + * skippable frames. + * @srcSize: The exact size of the series of frames. + * + * If any zstd encoded frame in the series doesn't have the frame content size + * set, `ZSTD_CONTENTSIZE_UNKNOWN` is returned. But frame content size is always + * set when using ZSTD_compress(). The decompressed size can be very large. + * If the source is untrusted, the decompressed size could be wrong or + * intentionally modified. Always ensure the result fits within the + * application's authorized limits. ZSTD_findDecompressedSize() handles multiple + * frames, and so it must traverse the input to read each frame header. This is + * efficient as most of the data is skipped, however it does mean that all frame + * data must be present and valid. + * + * Return: Decompressed size of all the data contained in the frames if known. + * `ZSTD_CONTENTSIZE_UNKNOWN` if the decompressed size is unknown. + * `ZSTD_CONTENTSIZE_ERROR` if an error occurred. + */ +unsigned long long ZSTD_findDecompressedSize(const void *src, size_t srcSize); + +/*-************************************* + * Advanced compression functions + **************************************/ +/** + * ZSTD_checkCParams() - ensure parameter values remain within authorized range + * @cParams: The zstd compression parameters. + * + * Return: Zero or an error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams); + +/** + * ZSTD_adjustCParams() - optimize parameters for a given srcSize and dictSize + * @srcSize: Optionally the estimated source size, or zero if unknown. + * @dictSize: Optionally the estimated dictionary size, or zero if unknown. + * + * Return: The optimized parameters. + */ +ZSTD_compressionParameters ZSTD_adjustCParams( + ZSTD_compressionParameters cParams, unsigned long long srcSize, + size_t dictSize); + +/*--- Advanced decompression functions ---*/ + +/** + * ZSTD_isFrame() - returns true iff the buffer starts with a valid frame + * @buffer: The source buffer to check. + * @size: The size of the source buffer, must be at least 4 bytes. + * + * Return: True iff the buffer starts with a zstd or skippable frame identifier. + */ +unsigned int ZSTD_isFrame(const void *buffer, size_t size); + +/** + * ZSTD_getDictID_fromDict() - returns the dictionary id stored in a dictionary + * @dict: The dictionary buffer. + * @dictSize: The size of the dictionary buffer. + * + * Return: The dictionary id stored within the dictionary or 0 if the + * dictionary is not a zstd dictionary. If it returns 0 the + * dictionary can still be loaded as a content-only dictionary. + */ +unsigned int ZSTD_getDictID_fromDict(const void *dict, size_t dictSize); + +/** + * ZSTD_getDictID_fromDDict() - returns the dictionary id stored in a ZSTD_DDict + * @ddict: The ddict to find the id of. + * + * Return: The dictionary id stored within `ddict` or 0 if the dictionary is not + * a zstd dictionary. If it returns 0 `ddict` will be loaded as a + * content-only dictionary. + */ +unsigned int ZSTD_getDictID_fromDDict(const ZSTD_DDict *ddict); + +/** + * ZSTD_getDictID_fromFrame() - returns the dictionary id stored in a zstd frame + * @src: Source buffer. It must be a zstd encoded frame. + * @srcSize: The size of the source buffer. It must be at least as large as the + * frame header. `ZSTD_frameHeaderSize_max` is always large enough. + * + * Return: The dictionary id required to decompress the frame stored within + * `src` or 0 if the dictionary id could not be decoded. It can return + * 0 if the frame does not require a dictionary, the dictionary id + * wasn't stored in the frame, `src` is not a zstd frame, or `srcSize` + * is too small. + */ +unsigned int ZSTD_getDictID_fromFrame(const void *src, size_t srcSize); + +/** + * struct ZSTD_frameParams - zstd frame parameters stored in the frame header + * @frameContentSize: The frame content size, or 0 if not present. + * @windowSize: The window size, or 0 if the frame is a skippable frame. + * @dictID: The dictionary id, or 0 if not present. + * @checksumFlag: Whether a checksum was used. + */ +typedef struct { + unsigned long long frameContentSize; + unsigned int windowSize; + unsigned int dictID; + unsigned int checksumFlag; +} ZSTD_frameParams; + +/** + * ZSTD_getFrameParams() - extracts parameters from a zstd or skippable frame + * @fparamsPtr: On success the frame parameters are written here. + * @src: The source buffer. It must point to a zstd or skippable frame. + * @srcSize: The size of the source buffer. `ZSTD_frameHeaderSize_max` is + * always large enough to succeed. + * + * Return: 0 on success. If more data is required it returns how many bytes + * must be provided to make forward progress. Otherwise it returns + * an error, which can be checked using ZSTD_isError(). + */ +size_t ZSTD_getFrameParams(ZSTD_frameParams *fparamsPtr, const void *src, + size_t srcSize); + +/*-***************************************************************************** + * Buffer-less and synchronous inner streaming functions + * + * This is an advanced API, giving full control over buffer management, for + * users which need direct control over memory. + * But it's also a complex one, with many restrictions (documented below). + * Prefer using normal streaming API for an easier experience + ******************************************************************************/ + +/*-***************************************************************************** + * Buffer-less streaming compression (synchronous mode) + * + * A ZSTD_CCtx object is required to track streaming operations. + * Use ZSTD_initCCtx() to initialize a context. + * ZSTD_CCtx object can be re-used multiple times within successive compression + * operations. + * + * Start by initializing a context. + * Use ZSTD_compressBegin(), or ZSTD_compressBegin_usingDict() for dictionary + * compression, + * or ZSTD_compressBegin_advanced(), for finer parameter control. + * It's also possible to duplicate a reference context which has already been + * initialized, using ZSTD_copyCCtx() + * + * Then, consume your input using ZSTD_compressContinue(). + * There are some important considerations to keep in mind when using this + * advanced function : + * - ZSTD_compressContinue() has no internal buffer. It uses externally provided + * buffer only. + * - Interface is synchronous : input is consumed entirely and produce 1+ + * (or more) compressed blocks. + * - Caller must ensure there is enough space in `dst` to store compressed data + * under worst case scenario. Worst case evaluation is provided by + * ZSTD_compressBound(). + * ZSTD_compressContinue() doesn't guarantee recover after a failed + * compression. + * - ZSTD_compressContinue() presumes prior input ***is still accessible and + * unmodified*** (up to maximum distance size, see WindowLog). + * It remembers all previous contiguous blocks, plus one separated memory + * segment (which can itself consists of multiple contiguous blocks) + * - ZSTD_compressContinue() detects that prior input has been overwritten when + * `src` buffer overlaps. In which case, it will "discard" the relevant memory + * section from its history. + * + * Finish a frame with ZSTD_compressEnd(), which will write the last block(s) + * and optional checksum. It's possible to use srcSize==0, in which case, it + * will write a final empty block to end the frame. Without last block mark, + * frames will be considered unfinished (corrupted) by decoders. + * + * `ZSTD_CCtx` object can be re-used (ZSTD_compressBegin()) to compress some new + * frame. + ******************************************************************************/ + +/*===== Buffer-less streaming compression functions =====*/ +size_t ZSTD_compressBegin(ZSTD_CCtx *cctx, int compressionLevel); +size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx *cctx, const void *dict, + size_t dictSize, int compressionLevel); +size_t ZSTD_compressBegin_advanced(ZSTD_CCtx *cctx, const void *dict, + size_t dictSize, ZSTD_parameters params, + unsigned long long pledgedSrcSize); +size_t ZSTD_copyCCtx(ZSTD_CCtx *cctx, const ZSTD_CCtx *preparedCCtx, + unsigned long long pledgedSrcSize); +size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx *cctx, const ZSTD_CDict *cdict, + unsigned long long pledgedSrcSize); +size_t ZSTD_compressContinue(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); +size_t ZSTD_compressEnd(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); + + + +/*-***************************************************************************** + * Buffer-less streaming decompression (synchronous mode) + * + * A ZSTD_DCtx object is required to track streaming operations. + * Use ZSTD_initDCtx() to initialize a context. + * A ZSTD_DCtx object can be re-used multiple times. + * + * First typical operation is to retrieve frame parameters, using + * ZSTD_getFrameParams(). It fills a ZSTD_frameParams structure which provide + * important information to correctly decode the frame, such as the minimum + * rolling buffer size to allocate to decompress data (`windowSize`), and the + * dictionary ID used. + * Note: content size is optional, it may not be present. 0 means unknown. + * Note that these values could be wrong, either because of data malformation, + * or because an attacker is spoofing deliberate false information. As a + * consequence, check that values remain within valid application range, + * especially `windowSize`, before allocation. Each application can set its own + * limit, depending on local restrictions. For extended interoperability, it is + * recommended to support at least 8 MB. + * Frame parameters are extracted from the beginning of the compressed frame. + * Data fragment must be large enough to ensure successful decoding, typically + * `ZSTD_frameHeaderSize_max` bytes. + * Result: 0: successful decoding, the `ZSTD_frameParams` structure is filled. + * >0: `srcSize` is too small, provide at least this many bytes. + * errorCode, which can be tested using ZSTD_isError(). + * + * Start decompression, with ZSTD_decompressBegin() or + * ZSTD_decompressBegin_usingDict(). Alternatively, you can copy a prepared + * context, using ZSTD_copyDCtx(). + * + * Then use ZSTD_nextSrcSizeToDecompress() and ZSTD_decompressContinue() + * alternatively. + * ZSTD_nextSrcSizeToDecompress() tells how many bytes to provide as 'srcSize' + * to ZSTD_decompressContinue(). + * ZSTD_decompressContinue() requires this _exact_ amount of bytes, or it will + * fail. + * + * The result of ZSTD_decompressContinue() is the number of bytes regenerated + * within 'dst' (necessarily <= dstCapacity). It can be zero, which is not an + * error; it just means ZSTD_decompressContinue() has decoded some metadata + * item. It can also be an error code, which can be tested with ZSTD_isError(). + * + * ZSTD_decompressContinue() needs previous data blocks during decompression, up + * to `windowSize`. They should preferably be located contiguously, prior to + * current block. Alternatively, a round buffer of sufficient size is also + * possible. Sufficient size is determined by frame parameters. + * ZSTD_decompressContinue() is very sensitive to contiguity, if 2 blocks don't + * follow each other, make sure that either the compressor breaks contiguity at + * the same place, or that previous contiguous segment is large enough to + * properly handle maximum back-reference. + * + * A frame is fully decoded when ZSTD_nextSrcSizeToDecompress() returns zero. + * Context can then be reset to start a new decompression. + * + * Note: it's possible to know if next input to present is a header or a block, + * using ZSTD_nextInputType(). This information is not required to properly + * decode a frame. + * + * == Special case: skippable frames == + * + * Skippable frames allow integration of user-defined data into a flow of + * concatenated frames. Skippable frames will be ignored (skipped) by a + * decompressor. The format of skippable frames is as follows: + * a) Skippable frame ID - 4 Bytes, Little endian format, any value from + * 0x184D2A50 to 0x184D2A5F + * b) Frame Size - 4 Bytes, Little endian format, unsigned 32-bits + * c) Frame Content - any content (User Data) of length equal to Frame Size + * For skippable frames ZSTD_decompressContinue() always returns 0. + * For skippable frames ZSTD_getFrameParams() returns fparamsPtr->windowLog==0 + * what means that a frame is skippable. + * Note: If fparamsPtr->frameContentSize==0, it is ambiguous: the frame might + * actually be a zstd encoded frame with no content. For purposes of + * decompression, it is valid in both cases to skip the frame using + * ZSTD_findFrameCompressedSize() to find its size in bytes. + * It also returns frame size as fparamsPtr->frameContentSize. + ******************************************************************************/ + +/*===== Buffer-less streaming decompression functions =====*/ +size_t ZSTD_decompressBegin(ZSTD_DCtx *dctx); +size_t ZSTD_decompressBegin_usingDict(ZSTD_DCtx *dctx, const void *dict, + size_t dictSize); +void ZSTD_copyDCtx(ZSTD_DCtx *dctx, const ZSTD_DCtx *preparedDCtx); +size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx *dctx); +size_t ZSTD_decompressContinue(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); +typedef enum { + ZSTDnit_frameHeader, + ZSTDnit_blockHeader, + ZSTDnit_block, + ZSTDnit_lastBlock, + ZSTDnit_checksum, + ZSTDnit_skippableFrame +} ZSTD_nextInputType_e; +ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx *dctx); + +/*-***************************************************************************** + * Block functions + * + * Block functions produce and decode raw zstd blocks, without frame metadata. + * Frame metadata cost is typically ~18 bytes, which can be non-negligible for + * very small blocks (< 100 bytes). User will have to take in charge required + * information to regenerate data, such as compressed and content sizes. + * + * A few rules to respect: + * - Compressing and decompressing require a context structure + * + Use ZSTD_initCCtx() and ZSTD_initDCtx() + * - It is necessary to init context before starting + * + compression : ZSTD_compressBegin() + * + decompression : ZSTD_decompressBegin() + * + variants _usingDict() are also allowed + * + copyCCtx() and copyDCtx() work too + * - Block size is limited, it must be <= ZSTD_getBlockSizeMax() + * + If you need to compress more, cut data into multiple blocks + * + Consider using the regular ZSTD_compress() instead, as frame metadata + * costs become negligible when source size is large. + * - When a block is considered not compressible enough, ZSTD_compressBlock() + * result will be zero. In which case, nothing is produced into `dst`. + * + User must test for such outcome and deal directly with uncompressed data + * + ZSTD_decompressBlock() doesn't accept uncompressed data as input!!! + * + In case of multiple successive blocks, decoder must be informed of + * uncompressed block existence to follow proper history. Use + * ZSTD_insertBlock() in such a case. + ******************************************************************************/ + +/* Define for static allocation */ +#define ZSTD_BLOCKSIZE_ABSOLUTEMAX (128 * 1024) +/*===== Raw zstd block functions =====*/ +size_t ZSTD_getBlockSizeMax(ZSTD_CCtx *cctx); +size_t ZSTD_compressBlock(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); +size_t ZSTD_decompressBlock(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, + const void *src, size_t srcSize); +size_t ZSTD_insertBlock(ZSTD_DCtx *dctx, const void *blockStart, + size_t blockSize); + +#endif /* ZSTD_H */ diff --git a/lib/Kconfig b/lib/Kconfig index 5e7541f46081..0d49ed0e6253 100644 --- a/lib/Kconfig +++ b/lib/Kconfig @@ -249,6 +249,14 @@ config LZ4HC_COMPRESS config LZ4_DECOMPRESS tristate +config ZSTD_COMPRESS + select XXHASH + tristate + +config ZSTD_DECOMPRESS + select XXHASH + tristate + source "lib/xz/Kconfig" # diff --git a/lib/Makefile b/lib/Makefile index d06b68ae55a3..d5c8a4ffff80 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -116,6 +116,8 @@ obj-$(CONFIG_LZO_DECOMPRESS) += lzo/ obj-$(CONFIG_LZ4_COMPRESS) += lz4/ obj-$(CONFIG_LZ4HC_COMPRESS) += lz4/ obj-$(CONFIG_LZ4_DECOMPRESS) += lz4/ +obj-$(CONFIG_ZSTD_COMPRESS) += zstd/ +obj-$(CONFIG_ZSTD_DECOMPRESS) += zstd/ obj-$(CONFIG_XZ_DEC) += xz/ obj-$(CONFIG_RAID6_PQ) += raid6/ diff --git a/lib/zstd/Makefile b/lib/zstd/Makefile new file mode 100644 index 000000000000..dd0a359c135b --- /dev/null +++ b/lib/zstd/Makefile @@ -0,0 +1,18 @@ +obj-$(CONFIG_ZSTD_COMPRESS) += zstd_compress.o +obj-$(CONFIG_ZSTD_DECOMPRESS) += zstd_decompress.o + +ccflags-y += -O3 + +# Object files unique to zstd_compress and zstd_decompress +zstd_compress-y := fse_compress.o huf_compress.o compress.o +zstd_decompress-y := huf_decompress.o decompress.o + +# These object files are shared between the modules. +# Always add them to zstd_compress. +# Unless both zstd_compress and zstd_decompress are built in +# then also add them to zstd_decompress. +zstd_compress-y += entropy_common.o fse_decompress.o zstd_common.o + +ifneq ($(CONFIG_ZSTD_COMPRESS)$(CONFIG_ZSTD_DECOMPRESS),yy) + zstd_decompress-y += entropy_common.o fse_decompress.o zstd_common.o +endif diff --git a/lib/zstd/bitstream.h b/lib/zstd/bitstream.h new file mode 100644 index 000000000000..a826b99e1d63 --- /dev/null +++ b/lib/zstd/bitstream.h @@ -0,0 +1,374 @@ +/* + * bitstream + * Part of FSE library + * header file (to include) + * Copyright (C) 2013-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ +#ifndef BITSTREAM_H_MODULE +#define BITSTREAM_H_MODULE + +/* +* This API consists of small unitary functions, which must be inlined for best performance. +* Since link-time-optimization is not available for all compilers, +* these functions are defined into a .h to be included. +*/ + +/*-**************************************** +* Dependencies +******************************************/ +#include "error_private.h" /* error codes and messages */ +#include "mem.h" /* unaligned access routines */ + +/*========================================= +* Target specific +=========================================*/ +#define STREAM_ACCUMULATOR_MIN_32 25 +#define STREAM_ACCUMULATOR_MIN_64 57 +#define STREAM_ACCUMULATOR_MIN ((U32)(ZSTD_32bits() ? STREAM_ACCUMULATOR_MIN_32 : STREAM_ACCUMULATOR_MIN_64)) + +/*-****************************************** +* bitStream encoding API (write forward) +********************************************/ +/* bitStream can mix input from multiple sources. +* A critical property of these streams is that they encode and decode in **reverse** direction. +* So the first bit sequence you add will be the last to be read, like a LIFO stack. +*/ +typedef struct { + size_t bitContainer; + int bitPos; + char *startPtr; + char *ptr; + char *endPtr; +} BIT_CStream_t; + +ZSTD_STATIC size_t BIT_initCStream(BIT_CStream_t *bitC, void *dstBuffer, size_t dstCapacity); +ZSTD_STATIC void BIT_addBits(BIT_CStream_t *bitC, size_t value, unsigned nbBits); +ZSTD_STATIC void BIT_flushBits(BIT_CStream_t *bitC); +ZSTD_STATIC size_t BIT_closeCStream(BIT_CStream_t *bitC); + +/* Start with initCStream, providing the size of buffer to write into. +* bitStream will never write outside of this buffer. +* `dstCapacity` must be >= sizeof(bitD->bitContainer), otherwise @return will be an error code. +* +* bits are first added to a local register. +* Local register is size_t, hence 64-bits on 64-bits systems, or 32-bits on 32-bits systems. +* Writing data into memory is an explicit operation, performed by the flushBits function. +* Hence keep track how many bits are potentially stored into local register to avoid register overflow. +* After a flushBits, a maximum of 7 bits might still be stored into local register. +* +* Avoid storing elements of more than 24 bits if you want compatibility with 32-bits bitstream readers. +* +* Last operation is to close the bitStream. +* The function returns the final size of CStream in bytes. +* If data couldn't fit into `dstBuffer`, it will return a 0 ( == not storable) +*/ + +/*-******************************************** +* bitStream decoding API (read backward) +**********************************************/ +typedef struct { + size_t bitContainer; + unsigned bitsConsumed; + const char *ptr; + const char *start; +} BIT_DStream_t; + +typedef enum { + BIT_DStream_unfinished = 0, + BIT_DStream_endOfBuffer = 1, + BIT_DStream_completed = 2, + BIT_DStream_overflow = 3 +} BIT_DStream_status; /* result of BIT_reloadDStream() */ +/* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */ + +ZSTD_STATIC size_t BIT_initDStream(BIT_DStream_t *bitD, const void *srcBuffer, size_t srcSize); +ZSTD_STATIC size_t BIT_readBits(BIT_DStream_t *bitD, unsigned nbBits); +ZSTD_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t *bitD); +ZSTD_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t *bitD); + +/* Start by invoking BIT_initDStream(). +* A chunk of the bitStream is then stored into a local register. +* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t). +* You can then retrieve bitFields stored into the local register, **in reverse order**. +* Local register is explicitly reloaded from memory by the BIT_reloadDStream() method. +* A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished. +* Otherwise, it can be less than that, so proceed accordingly. +* Checking if DStream has reached its end can be performed with BIT_endOfDStream(). +*/ + +/*-**************************************** +* unsafe API +******************************************/ +ZSTD_STATIC void BIT_addBitsFast(BIT_CStream_t *bitC, size_t value, unsigned nbBits); +/* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */ + +ZSTD_STATIC void BIT_flushBitsFast(BIT_CStream_t *bitC); +/* unsafe version; does not check buffer overflow */ + +ZSTD_STATIC size_t BIT_readBitsFast(BIT_DStream_t *bitD, unsigned nbBits); +/* faster, but works only if nbBits >= 1 */ + +/*-************************************************************** +* Internal functions +****************************************************************/ +ZSTD_STATIC unsigned BIT_highbit32(register U32 val) { return 31 - __builtin_clz(val); } + +/*===== Local Constants =====*/ +static const unsigned BIT_mask[] = {0, 1, 3, 7, 0xF, 0x1F, 0x3F, 0x7F, 0xFF, + 0x1FF, 0x3FF, 0x7FF, 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF, + 0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF, 0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF}; /* up to 26 bits */ + +/*-************************************************************** +* bitStream encoding +****************************************************************/ +/*! BIT_initCStream() : + * `dstCapacity` must be > sizeof(void*) + * @return : 0 if success, + otherwise an error code (can be tested using ERR_isError() ) */ +ZSTD_STATIC size_t BIT_initCStream(BIT_CStream_t *bitC, void *startPtr, size_t dstCapacity) +{ + bitC->bitContainer = 0; + bitC->bitPos = 0; + bitC->startPtr = (char *)startPtr; + bitC->ptr = bitC->startPtr; + bitC->endPtr = bitC->startPtr + dstCapacity - sizeof(bitC->ptr); + if (dstCapacity <= sizeof(bitC->ptr)) + return ERROR(dstSize_tooSmall); + return 0; +} + +/*! BIT_addBits() : + can add up to 26 bits into `bitC`. + Does not check for register overflow ! */ +ZSTD_STATIC void BIT_addBits(BIT_CStream_t *bitC, size_t value, unsigned nbBits) +{ + bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos; + bitC->bitPos += nbBits; +} + +/*! BIT_addBitsFast() : + * works only if `value` is _clean_, meaning all high bits above nbBits are 0 */ +ZSTD_STATIC void BIT_addBitsFast(BIT_CStream_t *bitC, size_t value, unsigned nbBits) +{ + bitC->bitContainer |= value << bitC->bitPos; + bitC->bitPos += nbBits; +} + +/*! BIT_flushBitsFast() : + * unsafe version; does not check buffer overflow */ +ZSTD_STATIC void BIT_flushBitsFast(BIT_CStream_t *bitC) +{ + size_t const nbBytes = bitC->bitPos >> 3; + ZSTD_writeLEST(bitC->ptr, bitC->bitContainer); + bitC->ptr += nbBytes; + bitC->bitPos &= 7; + bitC->bitContainer >>= nbBytes * 8; /* if bitPos >= sizeof(bitContainer)*8 --> undefined behavior */ +} + +/*! BIT_flushBits() : + * safe version; check for buffer overflow, and prevents it. + * note : does not signal buffer overflow. This will be revealed later on using BIT_closeCStream() */ +ZSTD_STATIC void BIT_flushBits(BIT_CStream_t *bitC) +{ + size_t const nbBytes = bitC->bitPos >> 3; + ZSTD_writeLEST(bitC->ptr, bitC->bitContainer); + bitC->ptr += nbBytes; + if (bitC->ptr > bitC->endPtr) + bitC->ptr = bitC->endPtr; + bitC->bitPos &= 7; + bitC->bitContainer >>= nbBytes * 8; /* if bitPos >= sizeof(bitContainer)*8 --> undefined behavior */ +} + +/*! BIT_closeCStream() : + * @return : size of CStream, in bytes, + or 0 if it could not fit into dstBuffer */ +ZSTD_STATIC size_t BIT_closeCStream(BIT_CStream_t *bitC) +{ + BIT_addBitsFast(bitC, 1, 1); /* endMark */ + BIT_flushBits(bitC); + + if (bitC->ptr >= bitC->endPtr) + return 0; /* doesn't fit within authorized budget : cancel */ + + return (bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0); +} + +/*-******************************************************** +* bitStream decoding +**********************************************************/ +/*! BIT_initDStream() : +* Initialize a BIT_DStream_t. +* `bitD` : a pointer to an already allocated BIT_DStream_t structure. +* `srcSize` must be the *exact* size of the bitStream, in bytes. +* @return : size of stream (== srcSize) or an errorCode if a problem is detected +*/ +ZSTD_STATIC size_t BIT_initDStream(BIT_DStream_t *bitD, const void *srcBuffer, size_t srcSize) +{ + if (srcSize < 1) { + memset(bitD, 0, sizeof(*bitD)); + return ERROR(srcSize_wrong); + } + + if (srcSize >= sizeof(bitD->bitContainer)) { /* normal case */ + bitD->start = (const char *)srcBuffer; + bitD->ptr = (const char *)srcBuffer + srcSize - sizeof(bitD->bitContainer); + bitD->bitContainer = ZSTD_readLEST(bitD->ptr); + { + BYTE const lastByte = ((const BYTE *)srcBuffer)[srcSize - 1]; + bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ + if (lastByte == 0) + return ERROR(GENERIC); /* endMark not present */ + } + } else { + bitD->start = (const char *)srcBuffer; + bitD->ptr = bitD->start; + bitD->bitContainer = *(const BYTE *)(bitD->start); + switch (srcSize) { + case 7: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[6]) << (sizeof(bitD->bitContainer) * 8 - 16); + case 6: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[5]) << (sizeof(bitD->bitContainer) * 8 - 24); + case 5: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[4]) << (sizeof(bitD->bitContainer) * 8 - 32); + case 4: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[3]) << 24; + case 3: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[2]) << 16; + case 2: bitD->bitContainer += (size_t)(((const BYTE *)(srcBuffer))[1]) << 8; + default:; + } + { + BYTE const lastByte = ((const BYTE *)srcBuffer)[srcSize - 1]; + bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; + if (lastByte == 0) + return ERROR(GENERIC); /* endMark not present */ + } + bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize) * 8; + } + + return srcSize; +} + +ZSTD_STATIC size_t BIT_getUpperBits(size_t bitContainer, U32 const start) { return bitContainer >> start; } + +ZSTD_STATIC size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits) { return (bitContainer >> start) & BIT_mask[nbBits]; } + +ZSTD_STATIC size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) { return bitContainer & BIT_mask[nbBits]; } + +/*! BIT_lookBits() : + * Provides next n bits from local register. + * local register is not modified. + * On 32-bits, maxNbBits==24. + * On 64-bits, maxNbBits==56. + * @return : value extracted + */ +ZSTD_STATIC size_t BIT_lookBits(const BIT_DStream_t *bitD, U32 nbBits) +{ + U32 const bitMask = sizeof(bitD->bitContainer) * 8 - 1; + return ((bitD->bitContainer << (bitD->bitsConsumed & bitMask)) >> 1) >> ((bitMask - nbBits) & bitMask); +} + +/*! BIT_lookBitsFast() : +* unsafe version; only works only if nbBits >= 1 */ +ZSTD_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t *bitD, U32 nbBits) +{ + U32 const bitMask = sizeof(bitD->bitContainer) * 8 - 1; + return (bitD->bitContainer << (bitD->bitsConsumed & bitMask)) >> (((bitMask + 1) - nbBits) & bitMask); +} + +ZSTD_STATIC void BIT_skipBits(BIT_DStream_t *bitD, U32 nbBits) { bitD->bitsConsumed += nbBits; } + +/*! BIT_readBits() : + * Read (consume) next n bits from local register and update. + * Pay attention to not read more than nbBits contained into local register. + * @return : extracted value. + */ +ZSTD_STATIC size_t BIT_readBits(BIT_DStream_t *bitD, U32 nbBits) +{ + size_t const value = BIT_lookBits(bitD, nbBits); + BIT_skipBits(bitD, nbBits); + return value; +} + +/*! BIT_readBitsFast() : +* unsafe version; only works only if nbBits >= 1 */ +ZSTD_STATIC size_t BIT_readBitsFast(BIT_DStream_t *bitD, U32 nbBits) +{ + size_t const value = BIT_lookBitsFast(bitD, nbBits); + BIT_skipBits(bitD, nbBits); + return value; +} + +/*! BIT_reloadDStream() : +* Refill `bitD` from buffer previously set in BIT_initDStream() . +* This function is safe, it guarantees it will not read beyond src buffer. +* @return : status of `BIT_DStream_t` internal register. + if status == BIT_DStream_unfinished, internal register is filled with >= (sizeof(bitD->bitContainer)*8 - 7) bits */ +ZSTD_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t *bitD) +{ + if (bitD->bitsConsumed > (sizeof(bitD->bitContainer) * 8)) /* should not happen => corruption detected */ + return BIT_DStream_overflow; + + if (bitD->ptr >= bitD->start + sizeof(bitD->bitContainer)) { + bitD->ptr -= bitD->bitsConsumed >> 3; + bitD->bitsConsumed &= 7; + bitD->bitContainer = ZSTD_readLEST(bitD->ptr); + return BIT_DStream_unfinished; + } + if (bitD->ptr == bitD->start) { + if (bitD->bitsConsumed < sizeof(bitD->bitContainer) * 8) + return BIT_DStream_endOfBuffer; + return BIT_DStream_completed; + } + { + U32 nbBytes = bitD->bitsConsumed >> 3; + BIT_DStream_status result = BIT_DStream_unfinished; + if (bitD->ptr - nbBytes < bitD->start) { + nbBytes = (U32)(bitD->ptr - bitD->start); /* ptr > start */ + result = BIT_DStream_endOfBuffer; + } + bitD->ptr -= nbBytes; + bitD->bitsConsumed -= nbBytes * 8; + bitD->bitContainer = ZSTD_readLEST(bitD->ptr); /* reminder : srcSize > sizeof(bitD) */ + return result; + } +} + +/*! BIT_endOfDStream() : +* @return Tells if DStream has exactly reached its end (all bits consumed). +*/ +ZSTD_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t *DStream) +{ + return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer) * 8)); +} + +#endif /* BITSTREAM_H_MODULE */ diff --git a/lib/zstd/compress.c b/lib/zstd/compress.c new file mode 100644 index 000000000000..f9166cf4f7a9 --- /dev/null +++ b/lib/zstd/compress.c @@ -0,0 +1,3484 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +/*-************************************* +* Dependencies +***************************************/ +#include "fse.h" +#include "huf.h" +#include "mem.h" +#include "zstd_internal.h" /* includes zstd.h */ +#include +#include +#include /* memset */ + +/*-************************************* +* Constants +***************************************/ +static const U32 g_searchStrength = 8; /* control skip over incompressible data */ +#define HASH_READ_SIZE 8 +typedef enum { ZSTDcs_created = 0, ZSTDcs_init, ZSTDcs_ongoing, ZSTDcs_ending } ZSTD_compressionStage_e; + +/*-************************************* +* Helper functions +***************************************/ +size_t ZSTD_compressBound(size_t srcSize) { return FSE_compressBound(srcSize) + 12; } + +/*-************************************* +* Sequence storage +***************************************/ +static void ZSTD_resetSeqStore(seqStore_t *ssPtr) +{ + ssPtr->lit = ssPtr->litStart; + ssPtr->sequences = ssPtr->sequencesStart; + ssPtr->longLengthID = 0; +} + +/*-************************************* +* Context memory management +***************************************/ +struct ZSTD_CCtx_s { + const BYTE *nextSrc; /* next block here to continue on curr prefix */ + const BYTE *base; /* All regular indexes relative to this position */ + const BYTE *dictBase; /* extDict indexes relative to this position */ + U32 dictLimit; /* below that point, need extDict */ + U32 lowLimit; /* below that point, no more data */ + U32 nextToUpdate; /* index from which to continue dictionary update */ + U32 nextToUpdate3; /* index from which to continue dictionary update */ + U32 hashLog3; /* dispatch table : larger == faster, more memory */ + U32 loadedDictEnd; /* index of end of dictionary */ + U32 forceWindow; /* force back-references to respect limit of 1< 3) ? 0 : MIN(ZSTD_HASHLOG3_MAX, cParams.windowLog); + size_t const h3Size = ((size_t)1) << hashLog3; + size_t const tableSpace = (chainSize + hSize + h3Size) * sizeof(U32); + size_t const optSpace = + ((MaxML + 1) + (MaxLL + 1) + (MaxOff + 1) + (1 << Litbits)) * sizeof(U32) + (ZSTD_OPT_NUM + 1) * (sizeof(ZSTD_match_t) + sizeof(ZSTD_optimal_t)); + size_t const workspaceSize = tableSpace + (256 * sizeof(U32)) /* huffTable */ + tokenSpace + + (((cParams.strategy == ZSTD_btopt) || (cParams.strategy == ZSTD_btopt2)) ? optSpace : 0); + + return ZSTD_ALIGN(sizeof(ZSTD_stack)) + ZSTD_ALIGN(sizeof(ZSTD_CCtx)) + ZSTD_ALIGN(workspaceSize); +} + +static ZSTD_CCtx *ZSTD_createCCtx_advanced(ZSTD_customMem customMem) +{ + ZSTD_CCtx *cctx; + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + cctx = (ZSTD_CCtx *)ZSTD_malloc(sizeof(ZSTD_CCtx), customMem); + if (!cctx) + return NULL; + memset(cctx, 0, sizeof(ZSTD_CCtx)); + cctx->customMem = customMem; + return cctx; +} + +ZSTD_CCtx *ZSTD_initCCtx(void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + ZSTD_CCtx *cctx = ZSTD_createCCtx_advanced(stackMem); + if (cctx) { + cctx->workSpace = ZSTD_stackAllocAll(cctx->customMem.opaque, &cctx->workSpaceSize); + } + return cctx; +} + +size_t ZSTD_freeCCtx(ZSTD_CCtx *cctx) +{ + if (cctx == NULL) + return 0; /* support free on NULL */ + ZSTD_free(cctx->workSpace, cctx->customMem); + ZSTD_free(cctx, cctx->customMem); + return 0; /* reserved as a potential error code in the future */ +} + +const seqStore_t *ZSTD_getSeqStore(const ZSTD_CCtx *ctx) /* hidden interface */ { return &(ctx->seqStore); } + +static ZSTD_parameters ZSTD_getParamsFromCCtx(const ZSTD_CCtx *cctx) { return cctx->params; } + +/** ZSTD_checkParams() : + ensure param values remain within authorized range. + @return : 0, or an error code if one value is beyond authorized range */ +size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams) +{ +#define CLAMPCHECK(val, min, max) \ + { \ + if ((val < min) | (val > max)) \ + return ERROR(compressionParameter_unsupported); \ + } + CLAMPCHECK(cParams.windowLog, ZSTD_WINDOWLOG_MIN, ZSTD_WINDOWLOG_MAX); + CLAMPCHECK(cParams.chainLog, ZSTD_CHAINLOG_MIN, ZSTD_CHAINLOG_MAX); + CLAMPCHECK(cParams.hashLog, ZSTD_HASHLOG_MIN, ZSTD_HASHLOG_MAX); + CLAMPCHECK(cParams.searchLog, ZSTD_SEARCHLOG_MIN, ZSTD_SEARCHLOG_MAX); + CLAMPCHECK(cParams.searchLength, ZSTD_SEARCHLENGTH_MIN, ZSTD_SEARCHLENGTH_MAX); + CLAMPCHECK(cParams.targetLength, ZSTD_TARGETLENGTH_MIN, ZSTD_TARGETLENGTH_MAX); + if ((U32)(cParams.strategy) > (U32)ZSTD_btopt2) + return ERROR(compressionParameter_unsupported); + return 0; +} + +/** ZSTD_cycleLog() : + * condition for correct operation : hashLog > 1 */ +static U32 ZSTD_cycleLog(U32 hashLog, ZSTD_strategy strat) +{ + U32 const btScale = ((U32)strat >= (U32)ZSTD_btlazy2); + return hashLog - btScale; +} + +/** ZSTD_adjustCParams() : + optimize `cPar` for a given input (`srcSize` and `dictSize`). + mostly downsizing to reduce memory consumption and initialization. + Both `srcSize` and `dictSize` are optional (use 0 if unknown), + but if both are 0, no optimization can be done. + Note : cPar is considered validated at this stage. Use ZSTD_checkParams() to ensure that. */ +ZSTD_compressionParameters ZSTD_adjustCParams(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize) +{ + if (srcSize + dictSize == 0) + return cPar; /* no size information available : no adjustment */ + + /* resize params, to use less memory when necessary */ + { + U32 const minSrcSize = (srcSize == 0) ? 500 : 0; + U64 const rSize = srcSize + dictSize + minSrcSize; + if (rSize < ((U64)1 << ZSTD_WINDOWLOG_MAX)) { + U32 const srcLog = MAX(ZSTD_HASHLOG_MIN, ZSTD_highbit32((U32)(rSize)-1) + 1); + if (cPar.windowLog > srcLog) + cPar.windowLog = srcLog; + } + } + if (cPar.hashLog > cPar.windowLog) + cPar.hashLog = cPar.windowLog; + { + U32 const cycleLog = ZSTD_cycleLog(cPar.chainLog, cPar.strategy); + if (cycleLog > cPar.windowLog) + cPar.chainLog -= (cycleLog - cPar.windowLog); + } + + if (cPar.windowLog < ZSTD_WINDOWLOG_ABSOLUTEMIN) + cPar.windowLog = ZSTD_WINDOWLOG_ABSOLUTEMIN; /* required for frame header */ + + return cPar; +} + +static U32 ZSTD_equivalentParams(ZSTD_parameters param1, ZSTD_parameters param2) +{ + return (param1.cParams.hashLog == param2.cParams.hashLog) & (param1.cParams.chainLog == param2.cParams.chainLog) & + (param1.cParams.strategy == param2.cParams.strategy) & ((param1.cParams.searchLength == 3) == (param2.cParams.searchLength == 3)); +} + +/*! ZSTD_continueCCtx() : + reuse CCtx without reset (note : requires no dictionary) */ +static size_t ZSTD_continueCCtx(ZSTD_CCtx *cctx, ZSTD_parameters params, U64 frameContentSize) +{ + U32 const end = (U32)(cctx->nextSrc - cctx->base); + cctx->params = params; + cctx->frameContentSize = frameContentSize; + cctx->lowLimit = end; + cctx->dictLimit = end; + cctx->nextToUpdate = end + 1; + cctx->stage = ZSTDcs_init; + cctx->dictID = 0; + cctx->loadedDictEnd = 0; + { + int i; + for (i = 0; i < ZSTD_REP_NUM; i++) + cctx->rep[i] = repStartValue[i]; + } + cctx->seqStore.litLengthSum = 0; /* force reset of btopt stats */ + xxh64_reset(&cctx->xxhState, 0); + return 0; +} + +typedef enum { ZSTDcrp_continue, ZSTDcrp_noMemset, ZSTDcrp_fullReset } ZSTD_compResetPolicy_e; + +/*! ZSTD_resetCCtx_advanced() : + note : `params` must be validated */ +static size_t ZSTD_resetCCtx_advanced(ZSTD_CCtx *zc, ZSTD_parameters params, U64 frameContentSize, ZSTD_compResetPolicy_e const crp) +{ + if (crp == ZSTDcrp_continue) + if (ZSTD_equivalentParams(params, zc->params)) { + zc->flagStaticTables = 0; + zc->flagStaticHufTable = HUF_repeat_none; + return ZSTD_continueCCtx(zc, params, frameContentSize); + } + + { + size_t const blockSize = MIN(ZSTD_BLOCKSIZE_ABSOLUTEMAX, (size_t)1 << params.cParams.windowLog); + U32 const divider = (params.cParams.searchLength == 3) ? 3 : 4; + size_t const maxNbSeq = blockSize / divider; + size_t const tokenSpace = blockSize + 11 * maxNbSeq; + size_t const chainSize = (params.cParams.strategy == ZSTD_fast) ? 0 : (1 << params.cParams.chainLog); + size_t const hSize = ((size_t)1) << params.cParams.hashLog; + U32 const hashLog3 = (params.cParams.searchLength > 3) ? 0 : MIN(ZSTD_HASHLOG3_MAX, params.cParams.windowLog); + size_t const h3Size = ((size_t)1) << hashLog3; + size_t const tableSpace = (chainSize + hSize + h3Size) * sizeof(U32); + void *ptr; + + /* Check if workSpace is large enough, alloc a new one if needed */ + { + size_t const optSpace = ((MaxML + 1) + (MaxLL + 1) + (MaxOff + 1) + (1 << Litbits)) * sizeof(U32) + + (ZSTD_OPT_NUM + 1) * (sizeof(ZSTD_match_t) + sizeof(ZSTD_optimal_t)); + size_t const neededSpace = tableSpace + (256 * sizeof(U32)) /* huffTable */ + tokenSpace + + (((params.cParams.strategy == ZSTD_btopt) || (params.cParams.strategy == ZSTD_btopt2)) ? optSpace : 0); + if (zc->workSpaceSize < neededSpace) { + ZSTD_free(zc->workSpace, zc->customMem); + zc->workSpace = ZSTD_malloc(neededSpace, zc->customMem); + if (zc->workSpace == NULL) + return ERROR(memory_allocation); + zc->workSpaceSize = neededSpace; + } + } + + if (crp != ZSTDcrp_noMemset) + memset(zc->workSpace, 0, tableSpace); /* reset tables only */ + xxh64_reset(&zc->xxhState, 0); + zc->hashLog3 = hashLog3; + zc->hashTable = (U32 *)(zc->workSpace); + zc->chainTable = zc->hashTable + hSize; + zc->hashTable3 = zc->chainTable + chainSize; + ptr = zc->hashTable3 + h3Size; + zc->hufTable = (HUF_CElt *)ptr; + zc->flagStaticTables = 0; + zc->flagStaticHufTable = HUF_repeat_none; + ptr = ((U32 *)ptr) + 256; /* note : HUF_CElt* is incomplete type, size is simulated using U32 */ + + zc->nextToUpdate = 1; + zc->nextSrc = NULL; + zc->base = NULL; + zc->dictBase = NULL; + zc->dictLimit = 0; + zc->lowLimit = 0; + zc->params = params; + zc->blockSize = blockSize; + zc->frameContentSize = frameContentSize; + { + int i; + for (i = 0; i < ZSTD_REP_NUM; i++) + zc->rep[i] = repStartValue[i]; + } + + if ((params.cParams.strategy == ZSTD_btopt) || (params.cParams.strategy == ZSTD_btopt2)) { + zc->seqStore.litFreq = (U32 *)ptr; + zc->seqStore.litLengthFreq = zc->seqStore.litFreq + (1 << Litbits); + zc->seqStore.matchLengthFreq = zc->seqStore.litLengthFreq + (MaxLL + 1); + zc->seqStore.offCodeFreq = zc->seqStore.matchLengthFreq + (MaxML + 1); + ptr = zc->seqStore.offCodeFreq + (MaxOff + 1); + zc->seqStore.matchTable = (ZSTD_match_t *)ptr; + ptr = zc->seqStore.matchTable + ZSTD_OPT_NUM + 1; + zc->seqStore.priceTable = (ZSTD_optimal_t *)ptr; + ptr = zc->seqStore.priceTable + ZSTD_OPT_NUM + 1; + zc->seqStore.litLengthSum = 0; + } + zc->seqStore.sequencesStart = (seqDef *)ptr; + ptr = zc->seqStore.sequencesStart + maxNbSeq; + zc->seqStore.llCode = (BYTE *)ptr; + zc->seqStore.mlCode = zc->seqStore.llCode + maxNbSeq; + zc->seqStore.ofCode = zc->seqStore.mlCode + maxNbSeq; + zc->seqStore.litStart = zc->seqStore.ofCode + maxNbSeq; + + zc->stage = ZSTDcs_init; + zc->dictID = 0; + zc->loadedDictEnd = 0; + + return 0; + } +} + +/* ZSTD_invalidateRepCodes() : + * ensures next compression will not use repcodes from previous block. + * Note : only works with regular variant; + * do not use with extDict variant ! */ +void ZSTD_invalidateRepCodes(ZSTD_CCtx *cctx) +{ + int i; + for (i = 0; i < ZSTD_REP_NUM; i++) + cctx->rep[i] = 0; +} + +/*! ZSTD_copyCCtx() : +* Duplicate an existing context `srcCCtx` into another one `dstCCtx`. +* Only works during stage ZSTDcs_init (i.e. after creation, but before first call to ZSTD_compressContinue()). +* @return : 0, or an error code */ +size_t ZSTD_copyCCtx(ZSTD_CCtx *dstCCtx, const ZSTD_CCtx *srcCCtx, unsigned long long pledgedSrcSize) +{ + if (srcCCtx->stage != ZSTDcs_init) + return ERROR(stage_wrong); + + memcpy(&dstCCtx->customMem, &srcCCtx->customMem, sizeof(ZSTD_customMem)); + { + ZSTD_parameters params = srcCCtx->params; + params.fParams.contentSizeFlag = (pledgedSrcSize > 0); + ZSTD_resetCCtx_advanced(dstCCtx, params, pledgedSrcSize, ZSTDcrp_noMemset); + } + + /* copy tables */ + { + size_t const chainSize = (srcCCtx->params.cParams.strategy == ZSTD_fast) ? 0 : (1 << srcCCtx->params.cParams.chainLog); + size_t const hSize = ((size_t)1) << srcCCtx->params.cParams.hashLog; + size_t const h3Size = (size_t)1 << srcCCtx->hashLog3; + size_t const tableSpace = (chainSize + hSize + h3Size) * sizeof(U32); + memcpy(dstCCtx->workSpace, srcCCtx->workSpace, tableSpace); + } + + /* copy dictionary offsets */ + dstCCtx->nextToUpdate = srcCCtx->nextToUpdate; + dstCCtx->nextToUpdate3 = srcCCtx->nextToUpdate3; + dstCCtx->nextSrc = srcCCtx->nextSrc; + dstCCtx->base = srcCCtx->base; + dstCCtx->dictBase = srcCCtx->dictBase; + dstCCtx->dictLimit = srcCCtx->dictLimit; + dstCCtx->lowLimit = srcCCtx->lowLimit; + dstCCtx->loadedDictEnd = srcCCtx->loadedDictEnd; + dstCCtx->dictID = srcCCtx->dictID; + + /* copy entropy tables */ + dstCCtx->flagStaticTables = srcCCtx->flagStaticTables; + dstCCtx->flagStaticHufTable = srcCCtx->flagStaticHufTable; + if (srcCCtx->flagStaticTables) { + memcpy(dstCCtx->litlengthCTable, srcCCtx->litlengthCTable, sizeof(dstCCtx->litlengthCTable)); + memcpy(dstCCtx->matchlengthCTable, srcCCtx->matchlengthCTable, sizeof(dstCCtx->matchlengthCTable)); + memcpy(dstCCtx->offcodeCTable, srcCCtx->offcodeCTable, sizeof(dstCCtx->offcodeCTable)); + } + if (srcCCtx->flagStaticHufTable) { + memcpy(dstCCtx->hufTable, srcCCtx->hufTable, 256 * 4); + } + + return 0; +} + +/*! ZSTD_reduceTable() : +* reduce table indexes by `reducerValue` */ +static void ZSTD_reduceTable(U32 *const table, U32 const size, U32 const reducerValue) +{ + U32 u; + for (u = 0; u < size; u++) { + if (table[u] < reducerValue) + table[u] = 0; + else + table[u] -= reducerValue; + } +} + +/*! ZSTD_reduceIndex() : +* rescale all indexes to avoid future overflow (indexes are U32) */ +static void ZSTD_reduceIndex(ZSTD_CCtx *zc, const U32 reducerValue) +{ + { + U32 const hSize = 1 << zc->params.cParams.hashLog; + ZSTD_reduceTable(zc->hashTable, hSize, reducerValue); + } + + { + U32 const chainSize = (zc->params.cParams.strategy == ZSTD_fast) ? 0 : (1 << zc->params.cParams.chainLog); + ZSTD_reduceTable(zc->chainTable, chainSize, reducerValue); + } + + { + U32 const h3Size = (zc->hashLog3) ? 1 << zc->hashLog3 : 0; + ZSTD_reduceTable(zc->hashTable3, h3Size, reducerValue); + } +} + +/*-******************************************************* +* Block entropic compression +*********************************************************/ + +/* See doc/zstd_compression_format.md for detailed format description */ + +size_t ZSTD_noCompressBlock(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + if (srcSize + ZSTD_blockHeaderSize > dstCapacity) + return ERROR(dstSize_tooSmall); + memcpy((BYTE *)dst + ZSTD_blockHeaderSize, src, srcSize); + ZSTD_writeLE24(dst, (U32)(srcSize << 2) + (U32)bt_raw); + return ZSTD_blockHeaderSize + srcSize; +} + +static size_t ZSTD_noCompressLiterals(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + BYTE *const ostart = (BYTE * const)dst; + U32 const flSize = 1 + (srcSize > 31) + (srcSize > 4095); + + if (srcSize + flSize > dstCapacity) + return ERROR(dstSize_tooSmall); + + switch (flSize) { + case 1: /* 2 - 1 - 5 */ ostart[0] = (BYTE)((U32)set_basic + (srcSize << 3)); break; + case 2: /* 2 - 2 - 12 */ ZSTD_writeLE16(ostart, (U16)((U32)set_basic + (1 << 2) + (srcSize << 4))); break; + default: /*note : should not be necessary : flSize is within {1,2,3} */ + case 3: /* 2 - 2 - 20 */ ZSTD_writeLE32(ostart, (U32)((U32)set_basic + (3 << 2) + (srcSize << 4))); break; + } + + memcpy(ostart + flSize, src, srcSize); + return srcSize + flSize; +} + +static size_t ZSTD_compressRleLiteralsBlock(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + BYTE *const ostart = (BYTE * const)dst; + U32 const flSize = 1 + (srcSize > 31) + (srcSize > 4095); + + (void)dstCapacity; /* dstCapacity already guaranteed to be >=4, hence large enough */ + + switch (flSize) { + case 1: /* 2 - 1 - 5 */ ostart[0] = (BYTE)((U32)set_rle + (srcSize << 3)); break; + case 2: /* 2 - 2 - 12 */ ZSTD_writeLE16(ostart, (U16)((U32)set_rle + (1 << 2) + (srcSize << 4))); break; + default: /*note : should not be necessary : flSize is necessarily within {1,2,3} */ + case 3: /* 2 - 2 - 20 */ ZSTD_writeLE32(ostart, (U32)((U32)set_rle + (3 << 2) + (srcSize << 4))); break; + } + + ostart[flSize] = *(const BYTE *)src; + return flSize + 1; +} + +static size_t ZSTD_minGain(size_t srcSize) { return (srcSize >> 6) + 2; } + +static size_t ZSTD_compressLiterals(ZSTD_CCtx *zc, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t const minGain = ZSTD_minGain(srcSize); + size_t const lhSize = 3 + (srcSize >= 1 KB) + (srcSize >= 16 KB); + BYTE *const ostart = (BYTE *)dst; + U32 singleStream = srcSize < 256; + symbolEncodingType_e hType = set_compressed; + size_t cLitSize; + +/* small ? don't even attempt compression (speed opt) */ +#define LITERAL_NOENTROPY 63 + { + size_t const minLitSize = zc->flagStaticHufTable == HUF_repeat_valid ? 6 : LITERAL_NOENTROPY; + if (srcSize <= minLitSize) + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + } + + if (dstCapacity < lhSize + 1) + return ERROR(dstSize_tooSmall); /* not enough space for compression */ + { + HUF_repeat repeat = zc->flagStaticHufTable; + int const preferRepeat = zc->params.cParams.strategy < ZSTD_lazy ? srcSize <= 1024 : 0; + if (repeat == HUF_repeat_valid && lhSize == 3) + singleStream = 1; + cLitSize = singleStream ? HUF_compress1X_repeat(ostart + lhSize, dstCapacity - lhSize, src, srcSize, 255, 11, zc->tmpCounters, + sizeof(zc->tmpCounters), zc->hufTable, &repeat, preferRepeat) + : HUF_compress4X_repeat(ostart + lhSize, dstCapacity - lhSize, src, srcSize, 255, 11, zc->tmpCounters, + sizeof(zc->tmpCounters), zc->hufTable, &repeat, preferRepeat); + if (repeat != HUF_repeat_none) { + hType = set_repeat; + } /* reused the existing table */ + else { + zc->flagStaticHufTable = HUF_repeat_check; + } /* now have a table to reuse */ + } + + if ((cLitSize == 0) | (cLitSize >= srcSize - minGain)) { + zc->flagStaticHufTable = HUF_repeat_none; + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + } + if (cLitSize == 1) { + zc->flagStaticHufTable = HUF_repeat_none; + return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); + } + + /* Build header */ + switch (lhSize) { + case 3: /* 2 - 2 - 10 - 10 */ + { + U32 const lhc = hType + ((!singleStream) << 2) + ((U32)srcSize << 4) + ((U32)cLitSize << 14); + ZSTD_writeLE24(ostart, lhc); + break; + } + case 4: /* 2 - 2 - 14 - 14 */ + { + U32 const lhc = hType + (2 << 2) + ((U32)srcSize << 4) + ((U32)cLitSize << 18); + ZSTD_writeLE32(ostart, lhc); + break; + } + default: /* should not be necessary, lhSize is only {3,4,5} */ + case 5: /* 2 - 2 - 18 - 18 */ + { + U32 const lhc = hType + (3 << 2) + ((U32)srcSize << 4) + ((U32)cLitSize << 22); + ZSTD_writeLE32(ostart, lhc); + ostart[4] = (BYTE)(cLitSize >> 10); + break; + } + } + return lhSize + cLitSize; +} + +static const BYTE LL_Code[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 17, 17, 18, 18, + 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, + 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24}; + +static const BYTE ML_Code[128] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 36, 36, 37, 37, 37, 37, 38, 38, 38, 38, + 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 40, 40, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, + 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42}; + +void ZSTD_seqToCodes(const seqStore_t *seqStorePtr) +{ + BYTE const LL_deltaCode = 19; + BYTE const ML_deltaCode = 36; + const seqDef *const sequences = seqStorePtr->sequencesStart; + BYTE *const llCodeTable = seqStorePtr->llCode; + BYTE *const ofCodeTable = seqStorePtr->ofCode; + BYTE *const mlCodeTable = seqStorePtr->mlCode; + U32 const nbSeq = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + U32 u; + for (u = 0; u < nbSeq; u++) { + U32 const llv = sequences[u].litLength; + U32 const mlv = sequences[u].matchLength; + llCodeTable[u] = (llv > 63) ? (BYTE)ZSTD_highbit32(llv) + LL_deltaCode : LL_Code[llv]; + ofCodeTable[u] = (BYTE)ZSTD_highbit32(sequences[u].offset); + mlCodeTable[u] = (mlv > 127) ? (BYTE)ZSTD_highbit32(mlv) + ML_deltaCode : ML_Code[mlv]; + } + if (seqStorePtr->longLengthID == 1) + llCodeTable[seqStorePtr->longLengthPos] = MaxLL; + if (seqStorePtr->longLengthID == 2) + mlCodeTable[seqStorePtr->longLengthPos] = MaxML; +} + +ZSTD_STATIC size_t ZSTD_compressSequences_internal(ZSTD_CCtx *zc, void *dst, size_t dstCapacity) +{ + const int longOffsets = zc->params.cParams.windowLog > STREAM_ACCUMULATOR_MIN; + const seqStore_t *seqStorePtr = &(zc->seqStore); + FSE_CTable *CTable_LitLength = zc->litlengthCTable; + FSE_CTable *CTable_OffsetBits = zc->offcodeCTable; + FSE_CTable *CTable_MatchLength = zc->matchlengthCTable; + U32 LLtype, Offtype, MLtype; /* compressed, raw or rle */ + const seqDef *const sequences = seqStorePtr->sequencesStart; + const BYTE *const ofCodeTable = seqStorePtr->ofCode; + const BYTE *const llCodeTable = seqStorePtr->llCode; + const BYTE *const mlCodeTable = seqStorePtr->mlCode; + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstCapacity; + BYTE *op = ostart; + size_t const nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; + BYTE *seqHead; + + U32 *count; + S16 *norm; + U32 *workspace; + size_t workspaceSize = sizeof(zc->tmpCounters); + { + size_t spaceUsed32 = 0; + count = (U32 *)zc->tmpCounters + spaceUsed32; + spaceUsed32 += MaxSeq + 1; + norm = (S16 *)((U32 *)zc->tmpCounters + spaceUsed32); + spaceUsed32 += ALIGN(sizeof(S16) * (MaxSeq + 1), sizeof(U32)) >> 2; + + workspace = (U32 *)zc->tmpCounters + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + } + + /* Compress literals */ + { + const BYTE *const literals = seqStorePtr->litStart; + size_t const litSize = seqStorePtr->lit - literals; + size_t const cSize = ZSTD_compressLiterals(zc, op, dstCapacity, literals, litSize); + if (ZSTD_isError(cSize)) + return cSize; + op += cSize; + } + + /* Sequences Header */ + if ((oend - op) < 3 /*max nbSeq Size*/ + 1 /*seqHead */) + return ERROR(dstSize_tooSmall); + if (nbSeq < 0x7F) + *op++ = (BYTE)nbSeq; + else if (nbSeq < LONGNBSEQ) + op[0] = (BYTE)((nbSeq >> 8) + 0x80), op[1] = (BYTE)nbSeq, op += 2; + else + op[0] = 0xFF, ZSTD_writeLE16(op + 1, (U16)(nbSeq - LONGNBSEQ)), op += 3; + if (nbSeq == 0) + return op - ostart; + + /* seqHead : flags for FSE encoding type */ + seqHead = op++; + +#define MIN_SEQ_FOR_DYNAMIC_FSE 64 +#define MAX_SEQ_FOR_STATIC_FSE 1000 + + /* convert length/distances into codes */ + ZSTD_seqToCodes(seqStorePtr); + + /* CTable for Literal Lengths */ + { + U32 max = MaxLL; + size_t const mostFrequent = FSE_countFast_wksp(count, &max, llCodeTable, nbSeq, workspace); + if ((mostFrequent == nbSeq) && (nbSeq > 2)) { + *op++ = llCodeTable[0]; + FSE_buildCTable_rle(CTable_LitLength, (BYTE)max); + LLtype = set_rle; + } else if ((zc->flagStaticTables) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) { + LLtype = set_repeat; + } else if ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (LL_defaultNormLog - 1)))) { + FSE_buildCTable_wksp(CTable_LitLength, LL_defaultNorm, MaxLL, LL_defaultNormLog, workspace, workspaceSize); + LLtype = set_basic; + } else { + size_t nbSeq_1 = nbSeq; + const U32 tableLog = FSE_optimalTableLog(LLFSELog, nbSeq, max); + if (count[llCodeTable[nbSeq - 1]] > 1) { + count[llCodeTable[nbSeq - 1]]--; + nbSeq_1--; + } + FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max); + { + size_t const NCountSize = FSE_writeNCount(op, oend - op, norm, max, tableLog); /* overflow protected */ + if (FSE_isError(NCountSize)) + return NCountSize; + op += NCountSize; + } + FSE_buildCTable_wksp(CTable_LitLength, norm, max, tableLog, workspace, workspaceSize); + LLtype = set_compressed; + } + } + + /* CTable for Offsets */ + { + U32 max = MaxOff; + size_t const mostFrequent = FSE_countFast_wksp(count, &max, ofCodeTable, nbSeq, workspace); + if ((mostFrequent == nbSeq) && (nbSeq > 2)) { + *op++ = ofCodeTable[0]; + FSE_buildCTable_rle(CTable_OffsetBits, (BYTE)max); + Offtype = set_rle; + } else if ((zc->flagStaticTables) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) { + Offtype = set_repeat; + } else if ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (OF_defaultNormLog - 1)))) { + FSE_buildCTable_wksp(CTable_OffsetBits, OF_defaultNorm, MaxOff, OF_defaultNormLog, workspace, workspaceSize); + Offtype = set_basic; + } else { + size_t nbSeq_1 = nbSeq; + const U32 tableLog = FSE_optimalTableLog(OffFSELog, nbSeq, max); + if (count[ofCodeTable[nbSeq - 1]] > 1) { + count[ofCodeTable[nbSeq - 1]]--; + nbSeq_1--; + } + FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max); + { + size_t const NCountSize = FSE_writeNCount(op, oend - op, norm, max, tableLog); /* overflow protected */ + if (FSE_isError(NCountSize)) + return NCountSize; + op += NCountSize; + } + FSE_buildCTable_wksp(CTable_OffsetBits, norm, max, tableLog, workspace, workspaceSize); + Offtype = set_compressed; + } + } + + /* CTable for MatchLengths */ + { + U32 max = MaxML; + size_t const mostFrequent = FSE_countFast_wksp(count, &max, mlCodeTable, nbSeq, workspace); + if ((mostFrequent == nbSeq) && (nbSeq > 2)) { + *op++ = *mlCodeTable; + FSE_buildCTable_rle(CTable_MatchLength, (BYTE)max); + MLtype = set_rle; + } else if ((zc->flagStaticTables) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) { + MLtype = set_repeat; + } else if ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (ML_defaultNormLog - 1)))) { + FSE_buildCTable_wksp(CTable_MatchLength, ML_defaultNorm, MaxML, ML_defaultNormLog, workspace, workspaceSize); + MLtype = set_basic; + } else { + size_t nbSeq_1 = nbSeq; + const U32 tableLog = FSE_optimalTableLog(MLFSELog, nbSeq, max); + if (count[mlCodeTable[nbSeq - 1]] > 1) { + count[mlCodeTable[nbSeq - 1]]--; + nbSeq_1--; + } + FSE_normalizeCount(norm, tableLog, count, nbSeq_1, max); + { + size_t const NCountSize = FSE_writeNCount(op, oend - op, norm, max, tableLog); /* overflow protected */ + if (FSE_isError(NCountSize)) + return NCountSize; + op += NCountSize; + } + FSE_buildCTable_wksp(CTable_MatchLength, norm, max, tableLog, workspace, workspaceSize); + MLtype = set_compressed; + } + } + + *seqHead = (BYTE)((LLtype << 6) + (Offtype << 4) + (MLtype << 2)); + zc->flagStaticTables = 0; + + /* Encoding Sequences */ + { + BIT_CStream_t blockStream; + FSE_CState_t stateMatchLength; + FSE_CState_t stateOffsetBits; + FSE_CState_t stateLitLength; + + CHECK_E(BIT_initCStream(&blockStream, op, oend - op), dstSize_tooSmall); /* not enough space remaining */ + + /* first symbols */ + FSE_initCState2(&stateMatchLength, CTable_MatchLength, mlCodeTable[nbSeq - 1]); + FSE_initCState2(&stateOffsetBits, CTable_OffsetBits, ofCodeTable[nbSeq - 1]); + FSE_initCState2(&stateLitLength, CTable_LitLength, llCodeTable[nbSeq - 1]); + BIT_addBits(&blockStream, sequences[nbSeq - 1].litLength, LL_bits[llCodeTable[nbSeq - 1]]); + if (ZSTD_32bits()) + BIT_flushBits(&blockStream); + BIT_addBits(&blockStream, sequences[nbSeq - 1].matchLength, ML_bits[mlCodeTable[nbSeq - 1]]); + if (ZSTD_32bits()) + BIT_flushBits(&blockStream); + if (longOffsets) { + U32 const ofBits = ofCodeTable[nbSeq - 1]; + int const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN - 1); + if (extraBits) { + BIT_addBits(&blockStream, sequences[nbSeq - 1].offset, extraBits); + BIT_flushBits(&blockStream); + } + BIT_addBits(&blockStream, sequences[nbSeq - 1].offset >> extraBits, ofBits - extraBits); + } else { + BIT_addBits(&blockStream, sequences[nbSeq - 1].offset, ofCodeTable[nbSeq - 1]); + } + BIT_flushBits(&blockStream); + + { + size_t n; + for (n = nbSeq - 2; n < nbSeq; n--) { /* intentional underflow */ + BYTE const llCode = llCodeTable[n]; + BYTE const ofCode = ofCodeTable[n]; + BYTE const mlCode = mlCodeTable[n]; + U32 const llBits = LL_bits[llCode]; + U32 const ofBits = ofCode; /* 32b*/ /* 64b*/ + U32 const mlBits = ML_bits[mlCode]; + /* (7)*/ /* (7)*/ + FSE_encodeSymbol(&blockStream, &stateOffsetBits, ofCode); /* 15 */ /* 15 */ + FSE_encodeSymbol(&blockStream, &stateMatchLength, mlCode); /* 24 */ /* 24 */ + if (ZSTD_32bits()) + BIT_flushBits(&blockStream); /* (7)*/ + FSE_encodeSymbol(&blockStream, &stateLitLength, llCode); /* 16 */ /* 33 */ + if (ZSTD_32bits() || (ofBits + mlBits + llBits >= 64 - 7 - (LLFSELog + MLFSELog + OffFSELog))) + BIT_flushBits(&blockStream); /* (7)*/ + BIT_addBits(&blockStream, sequences[n].litLength, llBits); + if (ZSTD_32bits() && ((llBits + mlBits) > 24)) + BIT_flushBits(&blockStream); + BIT_addBits(&blockStream, sequences[n].matchLength, mlBits); + if (ZSTD_32bits()) + BIT_flushBits(&blockStream); /* (7)*/ + if (longOffsets) { + int const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN - 1); + if (extraBits) { + BIT_addBits(&blockStream, sequences[n].offset, extraBits); + BIT_flushBits(&blockStream); /* (7)*/ + } + BIT_addBits(&blockStream, sequences[n].offset >> extraBits, ofBits - extraBits); /* 31 */ + } else { + BIT_addBits(&blockStream, sequences[n].offset, ofBits); /* 31 */ + } + BIT_flushBits(&blockStream); /* (7)*/ + } + } + + FSE_flushCState(&blockStream, &stateMatchLength); + FSE_flushCState(&blockStream, &stateOffsetBits); + FSE_flushCState(&blockStream, &stateLitLength); + + { + size_t const streamSize = BIT_closeCStream(&blockStream); + if (streamSize == 0) + return ERROR(dstSize_tooSmall); /* not enough space */ + op += streamSize; + } + } + return op - ostart; +} + +ZSTD_STATIC size_t ZSTD_compressSequences(ZSTD_CCtx *zc, void *dst, size_t dstCapacity, size_t srcSize) +{ + size_t const cSize = ZSTD_compressSequences_internal(zc, dst, dstCapacity); + size_t const minGain = ZSTD_minGain(srcSize); + size_t const maxCSize = srcSize - minGain; + /* If the srcSize <= dstCapacity, then there is enough space to write a + * raw uncompressed block. Since we ran out of space, the block must not + * be compressible, so fall back to a raw uncompressed block. + */ + int const uncompressibleError = cSize == ERROR(dstSize_tooSmall) && srcSize <= dstCapacity; + int i; + + if (ZSTD_isError(cSize) && !uncompressibleError) + return cSize; + if (cSize >= maxCSize || uncompressibleError) { + zc->flagStaticHufTable = HUF_repeat_none; + return 0; + } + /* confirm repcodes */ + for (i = 0; i < ZSTD_REP_NUM; i++) + zc->rep[i] = zc->repToConfirm[i]; + return cSize; +} + +/*! ZSTD_storeSeq() : + Store a sequence (literal length, literals, offset code and match length code) into seqStore_t. + `offsetCode` : distance to match, or 0 == repCode. + `matchCode` : matchLength - MINMATCH +*/ +ZSTD_STATIC void ZSTD_storeSeq(seqStore_t *seqStorePtr, size_t litLength, const void *literals, U32 offsetCode, size_t matchCode) +{ + /* copy Literals */ + ZSTD_wildcopy(seqStorePtr->lit, literals, litLength); + seqStorePtr->lit += litLength; + + /* literal Length */ + if (litLength > 0xFFFF) { + seqStorePtr->longLengthID = 1; + seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + } + seqStorePtr->sequences[0].litLength = (U16)litLength; + + /* match offset */ + seqStorePtr->sequences[0].offset = offsetCode + 1; + + /* match Length */ + if (matchCode > 0xFFFF) { + seqStorePtr->longLengthID = 2; + seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + } + seqStorePtr->sequences[0].matchLength = (U16)matchCode; + + seqStorePtr->sequences++; +} + +/*-************************************* +* Match length counter +***************************************/ +static unsigned ZSTD_NbCommonBytes(register size_t val) +{ + if (ZSTD_isLittleEndian()) { + if (ZSTD_64bits()) { + return (__builtin_ctzll((U64)val) >> 3); + } else { /* 32 bits */ + return (__builtin_ctz((U32)val) >> 3); + } + } else { /* Big Endian CPU */ + if (ZSTD_64bits()) { + return (__builtin_clzll(val) >> 3); + } else { /* 32 bits */ + return (__builtin_clz((U32)val) >> 3); + } + } +} + +static size_t ZSTD_count(const BYTE *pIn, const BYTE *pMatch, const BYTE *const pInLimit) +{ + const BYTE *const pStart = pIn; + const BYTE *const pInLoopLimit = pInLimit - (sizeof(size_t) - 1); + + while (pIn < pInLoopLimit) { + size_t const diff = ZSTD_readST(pMatch) ^ ZSTD_readST(pIn); + if (!diff) { + pIn += sizeof(size_t); + pMatch += sizeof(size_t); + continue; + } + pIn += ZSTD_NbCommonBytes(diff); + return (size_t)(pIn - pStart); + } + if (ZSTD_64bits()) + if ((pIn < (pInLimit - 3)) && (ZSTD_read32(pMatch) == ZSTD_read32(pIn))) { + pIn += 4; + pMatch += 4; + } + if ((pIn < (pInLimit - 1)) && (ZSTD_read16(pMatch) == ZSTD_read16(pIn))) { + pIn += 2; + pMatch += 2; + } + if ((pIn < pInLimit) && (*pMatch == *pIn)) + pIn++; + return (size_t)(pIn - pStart); +} + +/** ZSTD_count_2segments() : +* can count match length with `ip` & `match` in 2 different segments. +* convention : on reaching mEnd, match count continue starting from iStart +*/ +static size_t ZSTD_count_2segments(const BYTE *ip, const BYTE *match, const BYTE *iEnd, const BYTE *mEnd, const BYTE *iStart) +{ + const BYTE *const vEnd = MIN(ip + (mEnd - match), iEnd); + size_t const matchLength = ZSTD_count(ip, match, vEnd); + if (match + matchLength != mEnd) + return matchLength; + return matchLength + ZSTD_count(ip + matchLength, iStart, iEnd); +} + +/*-************************************* +* Hashes +***************************************/ +static const U32 prime3bytes = 506832829U; +static U32 ZSTD_hash3(U32 u, U32 h) { return ((u << (32 - 24)) * prime3bytes) >> (32 - h); } +ZSTD_STATIC size_t ZSTD_hash3Ptr(const void *ptr, U32 h) { return ZSTD_hash3(ZSTD_readLE32(ptr), h); } /* only in zstd_opt.h */ + +static const U32 prime4bytes = 2654435761U; +static U32 ZSTD_hash4(U32 u, U32 h) { return (u * prime4bytes) >> (32 - h); } +static size_t ZSTD_hash4Ptr(const void *ptr, U32 h) { return ZSTD_hash4(ZSTD_read32(ptr), h); } + +static const U64 prime5bytes = 889523592379ULL; +static size_t ZSTD_hash5(U64 u, U32 h) { return (size_t)(((u << (64 - 40)) * prime5bytes) >> (64 - h)); } +static size_t ZSTD_hash5Ptr(const void *p, U32 h) { return ZSTD_hash5(ZSTD_readLE64(p), h); } + +static const U64 prime6bytes = 227718039650203ULL; +static size_t ZSTD_hash6(U64 u, U32 h) { return (size_t)(((u << (64 - 48)) * prime6bytes) >> (64 - h)); } +static size_t ZSTD_hash6Ptr(const void *p, U32 h) { return ZSTD_hash6(ZSTD_readLE64(p), h); } + +static const U64 prime7bytes = 58295818150454627ULL; +static size_t ZSTD_hash7(U64 u, U32 h) { return (size_t)(((u << (64 - 56)) * prime7bytes) >> (64 - h)); } +static size_t ZSTD_hash7Ptr(const void *p, U32 h) { return ZSTD_hash7(ZSTD_readLE64(p), h); } + +static const U64 prime8bytes = 0xCF1BBCDCB7A56463ULL; +static size_t ZSTD_hash8(U64 u, U32 h) { return (size_t)(((u)*prime8bytes) >> (64 - h)); } +static size_t ZSTD_hash8Ptr(const void *p, U32 h) { return ZSTD_hash8(ZSTD_readLE64(p), h); } + +static size_t ZSTD_hashPtr(const void *p, U32 hBits, U32 mls) +{ + switch (mls) { + // case 3: return ZSTD_hash3Ptr(p, hBits); + default: + case 4: return ZSTD_hash4Ptr(p, hBits); + case 5: return ZSTD_hash5Ptr(p, hBits); + case 6: return ZSTD_hash6Ptr(p, hBits); + case 7: return ZSTD_hash7Ptr(p, hBits); + case 8: return ZSTD_hash8Ptr(p, hBits); + } +} + +/*-************************************* +* Fast Scan +***************************************/ +static void ZSTD_fillHashTable(ZSTD_CCtx *zc, const void *end, const U32 mls) +{ + U32 *const hashTable = zc->hashTable; + U32 const hBits = zc->params.cParams.hashLog; + const BYTE *const base = zc->base; + const BYTE *ip = base + zc->nextToUpdate; + const BYTE *const iend = ((const BYTE *)end) - HASH_READ_SIZE; + const size_t fastHashFillStep = 3; + + while (ip <= iend) { + hashTable[ZSTD_hashPtr(ip, hBits, mls)] = (U32)(ip - base); + ip += fastHashFillStep; + } +} + +FORCE_INLINE +void ZSTD_compressBlock_fast_generic(ZSTD_CCtx *cctx, const void *src, size_t srcSize, const U32 mls) +{ + U32 *const hashTable = cctx->hashTable; + U32 const hBits = cctx->params.cParams.hashLog; + seqStore_t *seqStorePtr = &(cctx->seqStore); + const BYTE *const base = cctx->base; + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const U32 lowestIndex = cctx->dictLimit; + const BYTE *const lowest = base + lowestIndex; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - HASH_READ_SIZE; + U32 offset_1 = cctx->rep[0], offset_2 = cctx->rep[1]; + U32 offsetSaved = 0; + + /* init */ + ip += (ip == lowest); + { + U32 const maxRep = (U32)(ip - lowest); + if (offset_2 > maxRep) + offsetSaved = offset_2, offset_2 = 0; + if (offset_1 > maxRep) + offsetSaved = offset_1, offset_1 = 0; + } + + /* Main Search Loop */ + while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */ + size_t mLength; + size_t const h = ZSTD_hashPtr(ip, hBits, mls); + U32 const curr = (U32)(ip - base); + U32 const matchIndex = hashTable[h]; + const BYTE *match = base + matchIndex; + hashTable[h] = curr; /* update hash table */ + + if ((offset_1 > 0) & (ZSTD_read32(ip + 1 - offset_1) == ZSTD_read32(ip + 1))) { + mLength = ZSTD_count(ip + 1 + 4, ip + 1 + 4 - offset_1, iend) + 4; + ip++; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH); + } else { + U32 offset; + if ((matchIndex <= lowestIndex) || (ZSTD_read32(match) != ZSTD_read32(ip))) { + ip += ((ip - anchor) >> g_searchStrength) + 1; + continue; + } + mLength = ZSTD_count(ip + 4, match + 4, iend) + 4; + offset = (U32)(ip - match); + while (((ip > anchor) & (match > lowest)) && (ip[-1] == match[-1])) { + ip--; + match--; + mLength++; + } /* catch up */ + offset_2 = offset_1; + offset_1 = offset; + + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH); + } + + /* match found */ + ip += mLength; + anchor = ip; + + if (ip <= ilimit) { + /* Fill Table */ + hashTable[ZSTD_hashPtr(base + curr + 2, hBits, mls)] = curr + 2; /* here because curr+2 could be > iend-8 */ + hashTable[ZSTD_hashPtr(ip - 2, hBits, mls)] = (U32)(ip - 2 - base); + /* check immediate repcode */ + while ((ip <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_2)))) { + /* store sequence */ + size_t const rLength = ZSTD_count(ip + 4, ip + 4 - offset_2, iend) + 4; + { + U32 const tmpOff = offset_2; + offset_2 = offset_1; + offset_1 = tmpOff; + } /* swap offset_2 <=> offset_1 */ + hashTable[ZSTD_hashPtr(ip, hBits, mls)] = (U32)(ip - base); + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, rLength - MINMATCH); + ip += rLength; + anchor = ip; + continue; /* faster when present ... (?) */ + } + } + } + + /* save reps for next block */ + cctx->repToConfirm[0] = offset_1 ? offset_1 : offsetSaved; + cctx->repToConfirm[1] = offset_2 ? offset_2 : offsetSaved; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +static void ZSTD_compressBlock_fast(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + const U32 mls = ctx->params.cParams.searchLength; + switch (mls) { + default: /* includes case 3 */ + case 4: ZSTD_compressBlock_fast_generic(ctx, src, srcSize, 4); return; + case 5: ZSTD_compressBlock_fast_generic(ctx, src, srcSize, 5); return; + case 6: ZSTD_compressBlock_fast_generic(ctx, src, srcSize, 6); return; + case 7: ZSTD_compressBlock_fast_generic(ctx, src, srcSize, 7); return; + } +} + +static void ZSTD_compressBlock_fast_extDict_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const U32 mls) +{ + U32 *hashTable = ctx->hashTable; + const U32 hBits = ctx->params.cParams.hashLog; + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const base = ctx->base; + const BYTE *const dictBase = ctx->dictBase; + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const U32 lowestIndex = ctx->lowLimit; + const BYTE *const dictStart = dictBase + lowestIndex; + const U32 dictLimit = ctx->dictLimit; + const BYTE *const lowPrefixPtr = base + dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + U32 offset_1 = ctx->rep[0], offset_2 = ctx->rep[1]; + + /* Search Loop */ + while (ip < ilimit) { /* < instead of <=, because (ip+1) */ + const size_t h = ZSTD_hashPtr(ip, hBits, mls); + const U32 matchIndex = hashTable[h]; + const BYTE *matchBase = matchIndex < dictLimit ? dictBase : base; + const BYTE *match = matchBase + matchIndex; + const U32 curr = (U32)(ip - base); + const U32 repIndex = curr + 1 - offset_1; /* offset_1 expected <= curr +1 */ + const BYTE *repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *repMatch = repBase + repIndex; + size_t mLength; + hashTable[h] = curr; /* update hash table */ + + if ((((U32)((dictLimit - 1) - repIndex) >= 3) /* intentional underflow */ & (repIndex > lowestIndex)) && + (ZSTD_read32(repMatch) == ZSTD_read32(ip + 1))) { + const BYTE *repMatchEnd = repIndex < dictLimit ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip + 1 + EQUAL_READ32, repMatch + EQUAL_READ32, iend, repMatchEnd, lowPrefixPtr) + EQUAL_READ32; + ip++; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH); + } else { + if ((matchIndex < lowestIndex) || (ZSTD_read32(match) != ZSTD_read32(ip))) { + ip += ((ip - anchor) >> g_searchStrength) + 1; + continue; + } + { + const BYTE *matchEnd = matchIndex < dictLimit ? dictEnd : iend; + const BYTE *lowMatchPtr = matchIndex < dictLimit ? dictStart : lowPrefixPtr; + U32 offset; + mLength = ZSTD_count_2segments(ip + EQUAL_READ32, match + EQUAL_READ32, iend, matchEnd, lowPrefixPtr) + EQUAL_READ32; + while (((ip > anchor) & (match > lowMatchPtr)) && (ip[-1] == match[-1])) { + ip--; + match--; + mLength++; + } /* catch up */ + offset = curr - matchIndex; + offset_2 = offset_1; + offset_1 = offset; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH); + } + } + + /* found a match : store it */ + ip += mLength; + anchor = ip; + + if (ip <= ilimit) { + /* Fill Table */ + hashTable[ZSTD_hashPtr(base + curr + 2, hBits, mls)] = curr + 2; + hashTable[ZSTD_hashPtr(ip - 2, hBits, mls)] = (U32)(ip - 2 - base); + /* check immediate repcode */ + while (ip <= ilimit) { + U32 const curr2 = (U32)(ip - base); + U32 const repIndex2 = curr2 - offset_2; + const BYTE *repMatch2 = repIndex2 < dictLimit ? dictBase + repIndex2 : base + repIndex2; + if ((((U32)((dictLimit - 1) - repIndex2) >= 3) & (repIndex2 > lowestIndex)) /* intentional overflow */ + && (ZSTD_read32(repMatch2) == ZSTD_read32(ip))) { + const BYTE *const repEnd2 = repIndex2 < dictLimit ? dictEnd : iend; + size_t repLength2 = + ZSTD_count_2segments(ip + EQUAL_READ32, repMatch2 + EQUAL_READ32, iend, repEnd2, lowPrefixPtr) + EQUAL_READ32; + U32 tmpOffset = offset_2; + offset_2 = offset_1; + offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, repLength2 - MINMATCH); + hashTable[ZSTD_hashPtr(ip, hBits, mls)] = curr2; + ip += repLength2; + anchor = ip; + continue; + } + break; + } + } + } + + /* save reps for next block */ + ctx->repToConfirm[0] = offset_1; + ctx->repToConfirm[1] = offset_2; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +static void ZSTD_compressBlock_fast_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + U32 const mls = ctx->params.cParams.searchLength; + switch (mls) { + default: /* includes case 3 */ + case 4: ZSTD_compressBlock_fast_extDict_generic(ctx, src, srcSize, 4); return; + case 5: ZSTD_compressBlock_fast_extDict_generic(ctx, src, srcSize, 5); return; + case 6: ZSTD_compressBlock_fast_extDict_generic(ctx, src, srcSize, 6); return; + case 7: ZSTD_compressBlock_fast_extDict_generic(ctx, src, srcSize, 7); return; + } +} + +/*-************************************* +* Double Fast +***************************************/ +static void ZSTD_fillDoubleHashTable(ZSTD_CCtx *cctx, const void *end, const U32 mls) +{ + U32 *const hashLarge = cctx->hashTable; + U32 const hBitsL = cctx->params.cParams.hashLog; + U32 *const hashSmall = cctx->chainTable; + U32 const hBitsS = cctx->params.cParams.chainLog; + const BYTE *const base = cctx->base; + const BYTE *ip = base + cctx->nextToUpdate; + const BYTE *const iend = ((const BYTE *)end) - HASH_READ_SIZE; + const size_t fastHashFillStep = 3; + + while (ip <= iend) { + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip - base); + hashLarge[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip - base); + ip += fastHashFillStep; + } +} + +FORCE_INLINE +void ZSTD_compressBlock_doubleFast_generic(ZSTD_CCtx *cctx, const void *src, size_t srcSize, const U32 mls) +{ + U32 *const hashLong = cctx->hashTable; + const U32 hBitsL = cctx->params.cParams.hashLog; + U32 *const hashSmall = cctx->chainTable; + const U32 hBitsS = cctx->params.cParams.chainLog; + seqStore_t *seqStorePtr = &(cctx->seqStore); + const BYTE *const base = cctx->base; + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const U32 lowestIndex = cctx->dictLimit; + const BYTE *const lowest = base + lowestIndex; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - HASH_READ_SIZE; + U32 offset_1 = cctx->rep[0], offset_2 = cctx->rep[1]; + U32 offsetSaved = 0; + + /* init */ + ip += (ip == lowest); + { + U32 const maxRep = (U32)(ip - lowest); + if (offset_2 > maxRep) + offsetSaved = offset_2, offset_2 = 0; + if (offset_1 > maxRep) + offsetSaved = offset_1, offset_1 = 0; + } + + /* Main Search Loop */ + while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */ + size_t mLength; + size_t const h2 = ZSTD_hashPtr(ip, hBitsL, 8); + size_t const h = ZSTD_hashPtr(ip, hBitsS, mls); + U32 const curr = (U32)(ip - base); + U32 const matchIndexL = hashLong[h2]; + U32 const matchIndexS = hashSmall[h]; + const BYTE *matchLong = base + matchIndexL; + const BYTE *match = base + matchIndexS; + hashLong[h2] = hashSmall[h] = curr; /* update hash tables */ + + if ((offset_1 > 0) & (ZSTD_read32(ip + 1 - offset_1) == ZSTD_read32(ip + 1))) { /* note : by construction, offset_1 <= curr */ + mLength = ZSTD_count(ip + 1 + 4, ip + 1 + 4 - offset_1, iend) + 4; + ip++; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH); + } else { + U32 offset; + if ((matchIndexL > lowestIndex) && (ZSTD_read64(matchLong) == ZSTD_read64(ip))) { + mLength = ZSTD_count(ip + 8, matchLong + 8, iend) + 8; + offset = (U32)(ip - matchLong); + while (((ip > anchor) & (matchLong > lowest)) && (ip[-1] == matchLong[-1])) { + ip--; + matchLong--; + mLength++; + } /* catch up */ + } else if ((matchIndexS > lowestIndex) && (ZSTD_read32(match) == ZSTD_read32(ip))) { + size_t const h3 = ZSTD_hashPtr(ip + 1, hBitsL, 8); + U32 const matchIndex3 = hashLong[h3]; + const BYTE *match3 = base + matchIndex3; + hashLong[h3] = curr + 1; + if ((matchIndex3 > lowestIndex) && (ZSTD_read64(match3) == ZSTD_read64(ip + 1))) { + mLength = ZSTD_count(ip + 9, match3 + 8, iend) + 8; + ip++; + offset = (U32)(ip - match3); + while (((ip > anchor) & (match3 > lowest)) && (ip[-1] == match3[-1])) { + ip--; + match3--; + mLength++; + } /* catch up */ + } else { + mLength = ZSTD_count(ip + 4, match + 4, iend) + 4; + offset = (U32)(ip - match); + while (((ip > anchor) & (match > lowest)) && (ip[-1] == match[-1])) { + ip--; + match--; + mLength++; + } /* catch up */ + } + } else { + ip += ((ip - anchor) >> g_searchStrength) + 1; + continue; + } + + offset_2 = offset_1; + offset_1 = offset; + + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH); + } + + /* match found */ + ip += mLength; + anchor = ip; + + if (ip <= ilimit) { + /* Fill Table */ + hashLong[ZSTD_hashPtr(base + curr + 2, hBitsL, 8)] = hashSmall[ZSTD_hashPtr(base + curr + 2, hBitsS, mls)] = + curr + 2; /* here because curr+2 could be > iend-8 */ + hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = hashSmall[ZSTD_hashPtr(ip - 2, hBitsS, mls)] = (U32)(ip - 2 - base); + + /* check immediate repcode */ + while ((ip <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_2)))) { + /* store sequence */ + size_t const rLength = ZSTD_count(ip + 4, ip + 4 - offset_2, iend) + 4; + { + U32 const tmpOff = offset_2; + offset_2 = offset_1; + offset_1 = tmpOff; + } /* swap offset_2 <=> offset_1 */ + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip - base); + hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip - base); + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, rLength - MINMATCH); + ip += rLength; + anchor = ip; + continue; /* faster when present ... (?) */ + } + } + } + + /* save reps for next block */ + cctx->repToConfirm[0] = offset_1 ? offset_1 : offsetSaved; + cctx->repToConfirm[1] = offset_2 ? offset_2 : offsetSaved; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +static void ZSTD_compressBlock_doubleFast(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + const U32 mls = ctx->params.cParams.searchLength; + switch (mls) { + default: /* includes case 3 */ + case 4: ZSTD_compressBlock_doubleFast_generic(ctx, src, srcSize, 4); return; + case 5: ZSTD_compressBlock_doubleFast_generic(ctx, src, srcSize, 5); return; + case 6: ZSTD_compressBlock_doubleFast_generic(ctx, src, srcSize, 6); return; + case 7: ZSTD_compressBlock_doubleFast_generic(ctx, src, srcSize, 7); return; + } +} + +static void ZSTD_compressBlock_doubleFast_extDict_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const U32 mls) +{ + U32 *const hashLong = ctx->hashTable; + U32 const hBitsL = ctx->params.cParams.hashLog; + U32 *const hashSmall = ctx->chainTable; + U32 const hBitsS = ctx->params.cParams.chainLog; + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const base = ctx->base; + const BYTE *const dictBase = ctx->dictBase; + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const U32 lowestIndex = ctx->lowLimit; + const BYTE *const dictStart = dictBase + lowestIndex; + const U32 dictLimit = ctx->dictLimit; + const BYTE *const lowPrefixPtr = base + dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + U32 offset_1 = ctx->rep[0], offset_2 = ctx->rep[1]; + + /* Search Loop */ + while (ip < ilimit) { /* < instead of <=, because (ip+1) */ + const size_t hSmall = ZSTD_hashPtr(ip, hBitsS, mls); + const U32 matchIndex = hashSmall[hSmall]; + const BYTE *matchBase = matchIndex < dictLimit ? dictBase : base; + const BYTE *match = matchBase + matchIndex; + + const size_t hLong = ZSTD_hashPtr(ip, hBitsL, 8); + const U32 matchLongIndex = hashLong[hLong]; + const BYTE *matchLongBase = matchLongIndex < dictLimit ? dictBase : base; + const BYTE *matchLong = matchLongBase + matchLongIndex; + + const U32 curr = (U32)(ip - base); + const U32 repIndex = curr + 1 - offset_1; /* offset_1 expected <= curr +1 */ + const BYTE *repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *repMatch = repBase + repIndex; + size_t mLength; + hashSmall[hSmall] = hashLong[hLong] = curr; /* update hash table */ + + if ((((U32)((dictLimit - 1) - repIndex) >= 3) /* intentional underflow */ & (repIndex > lowestIndex)) && + (ZSTD_read32(repMatch) == ZSTD_read32(ip + 1))) { + const BYTE *repMatchEnd = repIndex < dictLimit ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip + 1 + 4, repMatch + 4, iend, repMatchEnd, lowPrefixPtr) + 4; + ip++; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, 0, mLength - MINMATCH); + } else { + if ((matchLongIndex > lowestIndex) && (ZSTD_read64(matchLong) == ZSTD_read64(ip))) { + const BYTE *matchEnd = matchLongIndex < dictLimit ? dictEnd : iend; + const BYTE *lowMatchPtr = matchLongIndex < dictLimit ? dictStart : lowPrefixPtr; + U32 offset; + mLength = ZSTD_count_2segments(ip + 8, matchLong + 8, iend, matchEnd, lowPrefixPtr) + 8; + offset = curr - matchLongIndex; + while (((ip > anchor) & (matchLong > lowMatchPtr)) && (ip[-1] == matchLong[-1])) { + ip--; + matchLong--; + mLength++; + } /* catch up */ + offset_2 = offset_1; + offset_1 = offset; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH); + + } else if ((matchIndex > lowestIndex) && (ZSTD_read32(match) == ZSTD_read32(ip))) { + size_t const h3 = ZSTD_hashPtr(ip + 1, hBitsL, 8); + U32 const matchIndex3 = hashLong[h3]; + const BYTE *const match3Base = matchIndex3 < dictLimit ? dictBase : base; + const BYTE *match3 = match3Base + matchIndex3; + U32 offset; + hashLong[h3] = curr + 1; + if ((matchIndex3 > lowestIndex) && (ZSTD_read64(match3) == ZSTD_read64(ip + 1))) { + const BYTE *matchEnd = matchIndex3 < dictLimit ? dictEnd : iend; + const BYTE *lowMatchPtr = matchIndex3 < dictLimit ? dictStart : lowPrefixPtr; + mLength = ZSTD_count_2segments(ip + 9, match3 + 8, iend, matchEnd, lowPrefixPtr) + 8; + ip++; + offset = curr + 1 - matchIndex3; + while (((ip > anchor) & (match3 > lowMatchPtr)) && (ip[-1] == match3[-1])) { + ip--; + match3--; + mLength++; + } /* catch up */ + } else { + const BYTE *matchEnd = matchIndex < dictLimit ? dictEnd : iend; + const BYTE *lowMatchPtr = matchIndex < dictLimit ? dictStart : lowPrefixPtr; + mLength = ZSTD_count_2segments(ip + 4, match + 4, iend, matchEnd, lowPrefixPtr) + 4; + offset = curr - matchIndex; + while (((ip > anchor) & (match > lowMatchPtr)) && (ip[-1] == match[-1])) { + ip--; + match--; + mLength++; + } /* catch up */ + } + offset_2 = offset_1; + offset_1 = offset; + ZSTD_storeSeq(seqStorePtr, ip - anchor, anchor, offset + ZSTD_REP_MOVE, mLength - MINMATCH); + + } else { + ip += ((ip - anchor) >> g_searchStrength) + 1; + continue; + } + } + + /* found a match : store it */ + ip += mLength; + anchor = ip; + + if (ip <= ilimit) { + /* Fill Table */ + hashSmall[ZSTD_hashPtr(base + curr + 2, hBitsS, mls)] = curr + 2; + hashLong[ZSTD_hashPtr(base + curr + 2, hBitsL, 8)] = curr + 2; + hashSmall[ZSTD_hashPtr(ip - 2, hBitsS, mls)] = (U32)(ip - 2 - base); + hashLong[ZSTD_hashPtr(ip - 2, hBitsL, 8)] = (U32)(ip - 2 - base); + /* check immediate repcode */ + while (ip <= ilimit) { + U32 const curr2 = (U32)(ip - base); + U32 const repIndex2 = curr2 - offset_2; + const BYTE *repMatch2 = repIndex2 < dictLimit ? dictBase + repIndex2 : base + repIndex2; + if ((((U32)((dictLimit - 1) - repIndex2) >= 3) & (repIndex2 > lowestIndex)) /* intentional overflow */ + && (ZSTD_read32(repMatch2) == ZSTD_read32(ip))) { + const BYTE *const repEnd2 = repIndex2 < dictLimit ? dictEnd : iend; + size_t const repLength2 = + ZSTD_count_2segments(ip + EQUAL_READ32, repMatch2 + EQUAL_READ32, iend, repEnd2, lowPrefixPtr) + EQUAL_READ32; + U32 tmpOffset = offset_2; + offset_2 = offset_1; + offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, repLength2 - MINMATCH); + hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = curr2; + hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = curr2; + ip += repLength2; + anchor = ip; + continue; + } + break; + } + } + } + + /* save reps for next block */ + ctx->repToConfirm[0] = offset_1; + ctx->repToConfirm[1] = offset_2; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +static void ZSTD_compressBlock_doubleFast_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + U32 const mls = ctx->params.cParams.searchLength; + switch (mls) { + default: /* includes case 3 */ + case 4: ZSTD_compressBlock_doubleFast_extDict_generic(ctx, src, srcSize, 4); return; + case 5: ZSTD_compressBlock_doubleFast_extDict_generic(ctx, src, srcSize, 5); return; + case 6: ZSTD_compressBlock_doubleFast_extDict_generic(ctx, src, srcSize, 6); return; + case 7: ZSTD_compressBlock_doubleFast_extDict_generic(ctx, src, srcSize, 7); return; + } +} + +/*-************************************* +* Binary Tree search +***************************************/ +/** ZSTD_insertBt1() : add one or multiple positions to tree. +* ip : assumed <= iend-8 . +* @return : nb of positions added */ +static U32 ZSTD_insertBt1(ZSTD_CCtx *zc, const BYTE *const ip, const U32 mls, const BYTE *const iend, U32 nbCompares, U32 extDict) +{ + U32 *const hashTable = zc->hashTable; + U32 const hashLog = zc->params.cParams.hashLog; + size_t const h = ZSTD_hashPtr(ip, hashLog, mls); + U32 *const bt = zc->chainTable; + U32 const btLog = zc->params.cParams.chainLog - 1; + U32 const btMask = (1 << btLog) - 1; + U32 matchIndex = hashTable[h]; + size_t commonLengthSmaller = 0, commonLengthLarger = 0; + const BYTE *const base = zc->base; + const BYTE *const dictBase = zc->dictBase; + const U32 dictLimit = zc->dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const prefixStart = base + dictLimit; + const BYTE *match; + const U32 curr = (U32)(ip - base); + const U32 btLow = btMask >= curr ? 0 : curr - btMask; + U32 *smallerPtr = bt + 2 * (curr & btMask); + U32 *largerPtr = smallerPtr + 1; + U32 dummy32; /* to be nullified at the end */ + U32 const windowLow = zc->lowLimit; + U32 matchEndIdx = curr + 8; + size_t bestLength = 8; + + hashTable[h] = curr; /* Update Hash Table */ + + while (nbCompares-- && (matchIndex > windowLow)) { + U32 *const nextPtr = bt + 2 * (matchIndex & btMask); + size_t matchLength = MIN(commonLengthSmaller, commonLengthLarger); /* guaranteed minimum nb of common bytes */ + + if ((!extDict) || (matchIndex + matchLength >= dictLimit)) { + match = base + matchIndex; + if (match[matchLength] == ip[matchLength]) + matchLength += ZSTD_count(ip + matchLength + 1, match + matchLength + 1, iend) + 1; + } else { + match = dictBase + matchIndex; + matchLength += ZSTD_count_2segments(ip + matchLength, match + matchLength, iend, dictEnd, prefixStart); + if (matchIndex + matchLength >= dictLimit) + match = base + matchIndex; /* to prepare for next usage of match[matchLength] */ + } + + if (matchLength > bestLength) { + bestLength = matchLength; + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; + } + + if (ip + matchLength == iend) /* equal : no way to know if inf or sup */ + break; /* drop , to guarantee consistency ; miss a bit of compression, but other solutions can corrupt the tree */ + + if (match[matchLength] < ip[matchLength]) { /* necessarily within correct buffer */ + /* match is smaller than curr */ + *smallerPtr = matchIndex; /* update smaller idx */ + commonLengthSmaller = matchLength; /* all smaller will now have at least this guaranteed common length */ + if (matchIndex <= btLow) { + smallerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + smallerPtr = nextPtr + 1; /* new "smaller" => larger of match */ + matchIndex = nextPtr[1]; /* new matchIndex larger than previous (closer to curr) */ + } else { + /* match is larger than curr */ + *largerPtr = matchIndex; + commonLengthLarger = matchLength; + if (matchIndex <= btLow) { + largerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + largerPtr = nextPtr; + matchIndex = nextPtr[0]; + } + } + + *smallerPtr = *largerPtr = 0; + if (bestLength > 384) + return MIN(192, (U32)(bestLength - 384)); /* speed optimization */ + if (matchEndIdx > curr + 8) + return matchEndIdx - curr - 8; + return 1; +} + +static size_t ZSTD_insertBtAndFindBestMatch(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iend, size_t *offsetPtr, U32 nbCompares, const U32 mls, + U32 extDict) +{ + U32 *const hashTable = zc->hashTable; + U32 const hashLog = zc->params.cParams.hashLog; + size_t const h = ZSTD_hashPtr(ip, hashLog, mls); + U32 *const bt = zc->chainTable; + U32 const btLog = zc->params.cParams.chainLog - 1; + U32 const btMask = (1 << btLog) - 1; + U32 matchIndex = hashTable[h]; + size_t commonLengthSmaller = 0, commonLengthLarger = 0; + const BYTE *const base = zc->base; + const BYTE *const dictBase = zc->dictBase; + const U32 dictLimit = zc->dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const prefixStart = base + dictLimit; + const U32 curr = (U32)(ip - base); + const U32 btLow = btMask >= curr ? 0 : curr - btMask; + const U32 windowLow = zc->lowLimit; + U32 *smallerPtr = bt + 2 * (curr & btMask); + U32 *largerPtr = bt + 2 * (curr & btMask) + 1; + U32 matchEndIdx = curr + 8; + U32 dummy32; /* to be nullified at the end */ + size_t bestLength = 0; + + hashTable[h] = curr; /* Update Hash Table */ + + while (nbCompares-- && (matchIndex > windowLow)) { + U32 *const nextPtr = bt + 2 * (matchIndex & btMask); + size_t matchLength = MIN(commonLengthSmaller, commonLengthLarger); /* guaranteed minimum nb of common bytes */ + const BYTE *match; + + if ((!extDict) || (matchIndex + matchLength >= dictLimit)) { + match = base + matchIndex; + if (match[matchLength] == ip[matchLength]) + matchLength += ZSTD_count(ip + matchLength + 1, match + matchLength + 1, iend) + 1; + } else { + match = dictBase + matchIndex; + matchLength += ZSTD_count_2segments(ip + matchLength, match + matchLength, iend, dictEnd, prefixStart); + if (matchIndex + matchLength >= dictLimit) + match = base + matchIndex; /* to prepare for next usage of match[matchLength] */ + } + + if (matchLength > bestLength) { + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; + if ((4 * (int)(matchLength - bestLength)) > (int)(ZSTD_highbit32(curr - matchIndex + 1) - ZSTD_highbit32((U32)offsetPtr[0] + 1))) + bestLength = matchLength, *offsetPtr = ZSTD_REP_MOVE + curr - matchIndex; + if (ip + matchLength == iend) /* equal : no way to know if inf or sup */ + break; /* drop, to guarantee consistency (miss a little bit of compression) */ + } + + if (match[matchLength] < ip[matchLength]) { + /* match is smaller than curr */ + *smallerPtr = matchIndex; /* update smaller idx */ + commonLengthSmaller = matchLength; /* all smaller will now have at least this guaranteed common length */ + if (matchIndex <= btLow) { + smallerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + smallerPtr = nextPtr + 1; /* new "smaller" => larger of match */ + matchIndex = nextPtr[1]; /* new matchIndex larger than previous (closer to curr) */ + } else { + /* match is larger than curr */ + *largerPtr = matchIndex; + commonLengthLarger = matchLength; + if (matchIndex <= btLow) { + largerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + largerPtr = nextPtr; + matchIndex = nextPtr[0]; + } + } + + *smallerPtr = *largerPtr = 0; + + zc->nextToUpdate = (matchEndIdx > curr + 8) ? matchEndIdx - 8 : curr + 1; + return bestLength; +} + +static void ZSTD_updateTree(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iend, const U32 nbCompares, const U32 mls) +{ + const BYTE *const base = zc->base; + const U32 target = (U32)(ip - base); + U32 idx = zc->nextToUpdate; + + while (idx < target) + idx += ZSTD_insertBt1(zc, base + idx, mls, iend, nbCompares, 0); +} + +/** ZSTD_BtFindBestMatch() : Tree updater, providing best match */ +static size_t ZSTD_BtFindBestMatch(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, const U32 mls) +{ + if (ip < zc->base + zc->nextToUpdate) + return 0; /* skipped area */ + ZSTD_updateTree(zc, ip, iLimit, maxNbAttempts, mls); + return ZSTD_insertBtAndFindBestMatch(zc, ip, iLimit, offsetPtr, maxNbAttempts, mls, 0); +} + +static size_t ZSTD_BtFindBestMatch_selectMLS(ZSTD_CCtx *zc, /* Index table will be updated */ + const BYTE *ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, const U32 matchLengthSearch) +{ + switch (matchLengthSearch) { + default: /* includes case 3 */ + case 4: return ZSTD_BtFindBestMatch(zc, ip, iLimit, offsetPtr, maxNbAttempts, 4); + case 5: return ZSTD_BtFindBestMatch(zc, ip, iLimit, offsetPtr, maxNbAttempts, 5); + case 7: + case 6: return ZSTD_BtFindBestMatch(zc, ip, iLimit, offsetPtr, maxNbAttempts, 6); + } +} + +static void ZSTD_updateTree_extDict(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iend, const U32 nbCompares, const U32 mls) +{ + const BYTE *const base = zc->base; + const U32 target = (U32)(ip - base); + U32 idx = zc->nextToUpdate; + + while (idx < target) + idx += ZSTD_insertBt1(zc, base + idx, mls, iend, nbCompares, 1); +} + +/** Tree updater, providing best match */ +static size_t ZSTD_BtFindBestMatch_extDict(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, + const U32 mls) +{ + if (ip < zc->base + zc->nextToUpdate) + return 0; /* skipped area */ + ZSTD_updateTree_extDict(zc, ip, iLimit, maxNbAttempts, mls); + return ZSTD_insertBtAndFindBestMatch(zc, ip, iLimit, offsetPtr, maxNbAttempts, mls, 1); +} + +static size_t ZSTD_BtFindBestMatch_selectMLS_extDict(ZSTD_CCtx *zc, /* Index table will be updated */ + const BYTE *ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, + const U32 matchLengthSearch) +{ + switch (matchLengthSearch) { + default: /* includes case 3 */ + case 4: return ZSTD_BtFindBestMatch_extDict(zc, ip, iLimit, offsetPtr, maxNbAttempts, 4); + case 5: return ZSTD_BtFindBestMatch_extDict(zc, ip, iLimit, offsetPtr, maxNbAttempts, 5); + case 7: + case 6: return ZSTD_BtFindBestMatch_extDict(zc, ip, iLimit, offsetPtr, maxNbAttempts, 6); + } +} + +/* ********************************* +* Hash Chain +***********************************/ +#define NEXT_IN_CHAIN(d, mask) chainTable[(d)&mask] + +/* Update chains up to ip (excluded) + Assumption : always within prefix (i.e. not within extDict) */ +FORCE_INLINE +U32 ZSTD_insertAndFindFirstIndex(ZSTD_CCtx *zc, const BYTE *ip, U32 mls) +{ + U32 *const hashTable = zc->hashTable; + const U32 hashLog = zc->params.cParams.hashLog; + U32 *const chainTable = zc->chainTable; + const U32 chainMask = (1 << zc->params.cParams.chainLog) - 1; + const BYTE *const base = zc->base; + const U32 target = (U32)(ip - base); + U32 idx = zc->nextToUpdate; + + while (idx < target) { /* catch up */ + size_t const h = ZSTD_hashPtr(base + idx, hashLog, mls); + NEXT_IN_CHAIN(idx, chainMask) = hashTable[h]; + hashTable[h] = idx; + idx++; + } + + zc->nextToUpdate = target; + return hashTable[ZSTD_hashPtr(ip, hashLog, mls)]; +} + +/* inlining is important to hardwire a hot branch (template emulation) */ +FORCE_INLINE +size_t ZSTD_HcFindBestMatch_generic(ZSTD_CCtx *zc, /* Index table will be updated */ + const BYTE *const ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, const U32 mls, + const U32 extDict) +{ + U32 *const chainTable = zc->chainTable; + const U32 chainSize = (1 << zc->params.cParams.chainLog); + const U32 chainMask = chainSize - 1; + const BYTE *const base = zc->base; + const BYTE *const dictBase = zc->dictBase; + const U32 dictLimit = zc->dictLimit; + const BYTE *const prefixStart = base + dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const U32 lowLimit = zc->lowLimit; + const U32 curr = (U32)(ip - base); + const U32 minChain = curr > chainSize ? curr - chainSize : 0; + int nbAttempts = maxNbAttempts; + size_t ml = EQUAL_READ32 - 1; + + /* HC4 match finder */ + U32 matchIndex = ZSTD_insertAndFindFirstIndex(zc, ip, mls); + + for (; (matchIndex > lowLimit) & (nbAttempts > 0); nbAttempts--) { + const BYTE *match; + size_t currMl = 0; + if ((!extDict) || matchIndex >= dictLimit) { + match = base + matchIndex; + if (match[ml] == ip[ml]) /* potentially better */ + currMl = ZSTD_count(ip, match, iLimit); + } else { + match = dictBase + matchIndex; + if (ZSTD_read32(match) == ZSTD_read32(ip)) /* assumption : matchIndex <= dictLimit-4 (by table construction) */ + currMl = ZSTD_count_2segments(ip + EQUAL_READ32, match + EQUAL_READ32, iLimit, dictEnd, prefixStart) + EQUAL_READ32; + } + + /* save best solution */ + if (currMl > ml) { + ml = currMl; + *offsetPtr = curr - matchIndex + ZSTD_REP_MOVE; + if (ip + currMl == iLimit) + break; /* best possible, and avoid read overflow*/ + } + + if (matchIndex <= minChain) + break; + matchIndex = NEXT_IN_CHAIN(matchIndex, chainMask); + } + + return ml; +} + +FORCE_INLINE size_t ZSTD_HcFindBestMatch_selectMLS(ZSTD_CCtx *zc, const BYTE *ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, + const U32 matchLengthSearch) +{ + switch (matchLengthSearch) { + default: /* includes case 3 */ + case 4: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 4, 0); + case 5: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 5, 0); + case 7: + case 6: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 6, 0); + } +} + +FORCE_INLINE size_t ZSTD_HcFindBestMatch_extDict_selectMLS(ZSTD_CCtx *zc, const BYTE *ip, const BYTE *const iLimit, size_t *offsetPtr, const U32 maxNbAttempts, + const U32 matchLengthSearch) +{ + switch (matchLengthSearch) { + default: /* includes case 3 */ + case 4: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 4, 1); + case 5: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 5, 1); + case 7: + case 6: return ZSTD_HcFindBestMatch_generic(zc, ip, iLimit, offsetPtr, maxNbAttempts, 6, 1); + } +} + +/* ******************************* +* Common parser - lazy strategy +*********************************/ +FORCE_INLINE +void ZSTD_compressBlock_lazy_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const U32 searchMethod, const U32 depth) +{ + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + const BYTE *const base = ctx->base + ctx->dictLimit; + + U32 const maxSearches = 1 << ctx->params.cParams.searchLog; + U32 const mls = ctx->params.cParams.searchLength; + + typedef size_t (*searchMax_f)(ZSTD_CCtx * zc, const BYTE *ip, const BYTE *iLimit, size_t *offsetPtr, U32 maxNbAttempts, U32 matchLengthSearch); + searchMax_f const searchMax = searchMethod ? ZSTD_BtFindBestMatch_selectMLS : ZSTD_HcFindBestMatch_selectMLS; + U32 offset_1 = ctx->rep[0], offset_2 = ctx->rep[1], savedOffset = 0; + + /* init */ + ip += (ip == base); + ctx->nextToUpdate3 = ctx->nextToUpdate; + { + U32 const maxRep = (U32)(ip - base); + if (offset_2 > maxRep) + savedOffset = offset_2, offset_2 = 0; + if (offset_1 > maxRep) + savedOffset = offset_1, offset_1 = 0; + } + + /* Match Loop */ + while (ip < ilimit) { + size_t matchLength = 0; + size_t offset = 0; + const BYTE *start = ip + 1; + + /* check repCode */ + if ((offset_1 > 0) & (ZSTD_read32(ip + 1) == ZSTD_read32(ip + 1 - offset_1))) { + /* repcode : we take it */ + matchLength = ZSTD_count(ip + 1 + EQUAL_READ32, ip + 1 + EQUAL_READ32 - offset_1, iend) + EQUAL_READ32; + if (depth == 0) + goto _storeSequence; + } + + /* first search (depth 0) */ + { + size_t offsetFound = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offsetFound, maxSearches, mls); + if (ml2 > matchLength) + matchLength = ml2, start = ip, offset = offsetFound; + } + + if (matchLength < EQUAL_READ32) { + ip += ((ip - anchor) >> g_searchStrength) + 1; /* jump faster over incompressible sections */ + continue; + } + + /* let's try to find a better solution */ + if (depth >= 1) + while (ip < ilimit) { + ip++; + if ((offset) && ((offset_1 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_1)))) { + size_t const mlRep = ZSTD_count(ip + EQUAL_READ32, ip + EQUAL_READ32 - offset_1, iend) + EQUAL_READ32; + int const gain2 = (int)(mlRep * 3); + int const gain1 = (int)(matchLength * 3 - ZSTD_highbit32((U32)offset + 1) + 1); + if ((mlRep >= EQUAL_READ32) && (gain2 > gain1)) + matchLength = mlRep, offset = 0, start = ip; + } + { + size_t offset2 = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offset2, maxSearches, mls); + int const gain2 = (int)(ml2 * 4 - ZSTD_highbit32((U32)offset2 + 1)); /* raw approx */ + int const gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 4); + if ((ml2 >= EQUAL_READ32) && (gain2 > gain1)) { + matchLength = ml2, offset = offset2, start = ip; + continue; /* search a better one */ + } + } + + /* let's find an even better one */ + if ((depth == 2) && (ip < ilimit)) { + ip++; + if ((offset) && ((offset_1 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_1)))) { + size_t const ml2 = ZSTD_count(ip + EQUAL_READ32, ip + EQUAL_READ32 - offset_1, iend) + EQUAL_READ32; + int const gain2 = (int)(ml2 * 4); + int const gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 1); + if ((ml2 >= EQUAL_READ32) && (gain2 > gain1)) + matchLength = ml2, offset = 0, start = ip; + } + { + size_t offset2 = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offset2, maxSearches, mls); + int const gain2 = (int)(ml2 * 4 - ZSTD_highbit32((U32)offset2 + 1)); /* raw approx */ + int const gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 7); + if ((ml2 >= EQUAL_READ32) && (gain2 > gain1)) { + matchLength = ml2, offset = offset2, start = ip; + continue; + } + } + } + break; /* nothing found : store previous solution */ + } + + /* NOTE: + * start[-offset+ZSTD_REP_MOVE-1] is undefined behavior. + * (-offset+ZSTD_REP_MOVE-1) is unsigned, and is added to start, which + * overflows the pointer, which is undefined behavior. + */ + /* catch up */ + if (offset) { + while ((start > anchor) && (start > base + offset - ZSTD_REP_MOVE) && + (start[-1] == (start-offset+ZSTD_REP_MOVE)[-1])) /* only search for offset within prefix */ + { + start--; + matchLength++; + } + offset_2 = offset_1; + offset_1 = (U32)(offset - ZSTD_REP_MOVE); + } + + /* store sequence */ +_storeSequence: + { + size_t const litLength = start - anchor; + ZSTD_storeSeq(seqStorePtr, litLength, anchor, (U32)offset, matchLength - MINMATCH); + anchor = ip = start + matchLength; + } + + /* check immediate repcode */ + while ((ip <= ilimit) && ((offset_2 > 0) & (ZSTD_read32(ip) == ZSTD_read32(ip - offset_2)))) { + /* store sequence */ + matchLength = ZSTD_count(ip + EQUAL_READ32, ip + EQUAL_READ32 - offset_2, iend) + EQUAL_READ32; + offset = offset_2; + offset_2 = offset_1; + offset_1 = (U32)offset; /* swap repcodes */ + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, matchLength - MINMATCH); + ip += matchLength; + anchor = ip; + continue; /* faster when present ... (?) */ + } + } + + /* Save reps for next block */ + ctx->repToConfirm[0] = offset_1 ? offset_1 : savedOffset; + ctx->repToConfirm[1] = offset_2 ? offset_2 : savedOffset; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +static void ZSTD_compressBlock_btlazy2(ZSTD_CCtx *ctx, const void *src, size_t srcSize) { ZSTD_compressBlock_lazy_generic(ctx, src, srcSize, 1, 2); } + +static void ZSTD_compressBlock_lazy2(ZSTD_CCtx *ctx, const void *src, size_t srcSize) { ZSTD_compressBlock_lazy_generic(ctx, src, srcSize, 0, 2); } + +static void ZSTD_compressBlock_lazy(ZSTD_CCtx *ctx, const void *src, size_t srcSize) { ZSTD_compressBlock_lazy_generic(ctx, src, srcSize, 0, 1); } + +static void ZSTD_compressBlock_greedy(ZSTD_CCtx *ctx, const void *src, size_t srcSize) { ZSTD_compressBlock_lazy_generic(ctx, src, srcSize, 0, 0); } + +FORCE_INLINE +void ZSTD_compressBlock_lazy_extDict_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const U32 searchMethod, const U32 depth) +{ + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + const BYTE *const base = ctx->base; + const U32 dictLimit = ctx->dictLimit; + const U32 lowestIndex = ctx->lowLimit; + const BYTE *const prefixStart = base + dictLimit; + const BYTE *const dictBase = ctx->dictBase; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const dictStart = dictBase + ctx->lowLimit; + + const U32 maxSearches = 1 << ctx->params.cParams.searchLog; + const U32 mls = ctx->params.cParams.searchLength; + + typedef size_t (*searchMax_f)(ZSTD_CCtx * zc, const BYTE *ip, const BYTE *iLimit, size_t *offsetPtr, U32 maxNbAttempts, U32 matchLengthSearch); + searchMax_f searchMax = searchMethod ? ZSTD_BtFindBestMatch_selectMLS_extDict : ZSTD_HcFindBestMatch_extDict_selectMLS; + + U32 offset_1 = ctx->rep[0], offset_2 = ctx->rep[1]; + + /* init */ + ctx->nextToUpdate3 = ctx->nextToUpdate; + ip += (ip == prefixStart); + + /* Match Loop */ + while (ip < ilimit) { + size_t matchLength = 0; + size_t offset = 0; + const BYTE *start = ip + 1; + U32 curr = (U32)(ip - base); + + /* check repCode */ + { + const U32 repIndex = (U32)(curr + 1 - offset_1); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + if (ZSTD_read32(ip + 1) == ZSTD_read32(repMatch)) { + /* repcode detected we should take it */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + matchLength = + ZSTD_count_2segments(ip + 1 + EQUAL_READ32, repMatch + EQUAL_READ32, iend, repEnd, prefixStart) + EQUAL_READ32; + if (depth == 0) + goto _storeSequence; + } + } + + /* first search (depth 0) */ + { + size_t offsetFound = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offsetFound, maxSearches, mls); + if (ml2 > matchLength) + matchLength = ml2, start = ip, offset = offsetFound; + } + + if (matchLength < EQUAL_READ32) { + ip += ((ip - anchor) >> g_searchStrength) + 1; /* jump faster over incompressible sections */ + continue; + } + + /* let's try to find a better solution */ + if (depth >= 1) + while (ip < ilimit) { + ip++; + curr++; + /* check repCode */ + if (offset) { + const U32 repIndex = (U32)(curr - offset_1); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + if (ZSTD_read32(ip) == ZSTD_read32(repMatch)) { + /* repcode detected */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + size_t const repLength = + ZSTD_count_2segments(ip + EQUAL_READ32, repMatch + EQUAL_READ32, iend, repEnd, prefixStart) + + EQUAL_READ32; + int const gain2 = (int)(repLength * 3); + int const gain1 = (int)(matchLength * 3 - ZSTD_highbit32((U32)offset + 1) + 1); + if ((repLength >= EQUAL_READ32) && (gain2 > gain1)) + matchLength = repLength, offset = 0, start = ip; + } + } + + /* search match, depth 1 */ + { + size_t offset2 = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offset2, maxSearches, mls); + int const gain2 = (int)(ml2 * 4 - ZSTD_highbit32((U32)offset2 + 1)); /* raw approx */ + int const gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 4); + if ((ml2 >= EQUAL_READ32) && (gain2 > gain1)) { + matchLength = ml2, offset = offset2, start = ip; + continue; /* search a better one */ + } + } + + /* let's find an even better one */ + if ((depth == 2) && (ip < ilimit)) { + ip++; + curr++; + /* check repCode */ + if (offset) { + const U32 repIndex = (U32)(curr - offset_1); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + if (ZSTD_read32(ip) == ZSTD_read32(repMatch)) { + /* repcode detected */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + size_t repLength = ZSTD_count_2segments(ip + EQUAL_READ32, repMatch + EQUAL_READ32, iend, + repEnd, prefixStart) + + EQUAL_READ32; + int gain2 = (int)(repLength * 4); + int gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 1); + if ((repLength >= EQUAL_READ32) && (gain2 > gain1)) + matchLength = repLength, offset = 0, start = ip; + } + } + + /* search match, depth 2 */ + { + size_t offset2 = 99999999; + size_t const ml2 = searchMax(ctx, ip, iend, &offset2, maxSearches, mls); + int const gain2 = (int)(ml2 * 4 - ZSTD_highbit32((U32)offset2 + 1)); /* raw approx */ + int const gain1 = (int)(matchLength * 4 - ZSTD_highbit32((U32)offset + 1) + 7); + if ((ml2 >= EQUAL_READ32) && (gain2 > gain1)) { + matchLength = ml2, offset = offset2, start = ip; + continue; + } + } + } + break; /* nothing found : store previous solution */ + } + + /* catch up */ + if (offset) { + U32 const matchIndex = (U32)((start - base) - (offset - ZSTD_REP_MOVE)); + const BYTE *match = (matchIndex < dictLimit) ? dictBase + matchIndex : base + matchIndex; + const BYTE *const mStart = (matchIndex < dictLimit) ? dictStart : prefixStart; + while ((start > anchor) && (match > mStart) && (start[-1] == match[-1])) { + start--; + match--; + matchLength++; + } /* catch up */ + offset_2 = offset_1; + offset_1 = (U32)(offset - ZSTD_REP_MOVE); + } + + /* store sequence */ + _storeSequence : { + size_t const litLength = start - anchor; + ZSTD_storeSeq(seqStorePtr, litLength, anchor, (U32)offset, matchLength - MINMATCH); + anchor = ip = start + matchLength; + } + + /* check immediate repcode */ + while (ip <= ilimit) { + const U32 repIndex = (U32)((ip - base) - offset_2); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + if (ZSTD_read32(ip) == ZSTD_read32(repMatch)) { + /* repcode detected we should take it */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + matchLength = + ZSTD_count_2segments(ip + EQUAL_READ32, repMatch + EQUAL_READ32, iend, repEnd, prefixStart) + EQUAL_READ32; + offset = offset_2; + offset_2 = offset_1; + offset_1 = (U32)offset; /* swap offset history */ + ZSTD_storeSeq(seqStorePtr, 0, anchor, 0, matchLength - MINMATCH); + ip += matchLength; + anchor = ip; + continue; /* faster when present ... (?) */ + } + break; + } + } + + /* Save reps for next block */ + ctx->repToConfirm[0] = offset_1; + ctx->repToConfirm[1] = offset_2; + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +void ZSTD_compressBlock_greedy_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) { ZSTD_compressBlock_lazy_extDict_generic(ctx, src, srcSize, 0, 0); } + +static void ZSTD_compressBlock_lazy_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + ZSTD_compressBlock_lazy_extDict_generic(ctx, src, srcSize, 0, 1); +} + +static void ZSTD_compressBlock_lazy2_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + ZSTD_compressBlock_lazy_extDict_generic(ctx, src, srcSize, 0, 2); +} + +static void ZSTD_compressBlock_btlazy2_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ + ZSTD_compressBlock_lazy_extDict_generic(ctx, src, srcSize, 1, 2); +} + +/* The optimal parser */ +#include "zstd_opt.h" + +static void ZSTD_compressBlock_btopt(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ +#ifdef ZSTD_OPT_H_91842398743 + ZSTD_compressBlock_opt_generic(ctx, src, srcSize, 0); +#else + (void)ctx; + (void)src; + (void)srcSize; + return; +#endif +} + +static void ZSTD_compressBlock_btopt2(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ +#ifdef ZSTD_OPT_H_91842398743 + ZSTD_compressBlock_opt_generic(ctx, src, srcSize, 1); +#else + (void)ctx; + (void)src; + (void)srcSize; + return; +#endif +} + +static void ZSTD_compressBlock_btopt_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ +#ifdef ZSTD_OPT_H_91842398743 + ZSTD_compressBlock_opt_extDict_generic(ctx, src, srcSize, 0); +#else + (void)ctx; + (void)src; + (void)srcSize; + return; +#endif +} + +static void ZSTD_compressBlock_btopt2_extDict(ZSTD_CCtx *ctx, const void *src, size_t srcSize) +{ +#ifdef ZSTD_OPT_H_91842398743 + ZSTD_compressBlock_opt_extDict_generic(ctx, src, srcSize, 1); +#else + (void)ctx; + (void)src; + (void)srcSize; + return; +#endif +} + +typedef void (*ZSTD_blockCompressor)(ZSTD_CCtx *ctx, const void *src, size_t srcSize); + +static ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, int extDict) +{ + static const ZSTD_blockCompressor blockCompressor[2][8] = { + {ZSTD_compressBlock_fast, ZSTD_compressBlock_doubleFast, ZSTD_compressBlock_greedy, ZSTD_compressBlock_lazy, ZSTD_compressBlock_lazy2, + ZSTD_compressBlock_btlazy2, ZSTD_compressBlock_btopt, ZSTD_compressBlock_btopt2}, + {ZSTD_compressBlock_fast_extDict, ZSTD_compressBlock_doubleFast_extDict, ZSTD_compressBlock_greedy_extDict, ZSTD_compressBlock_lazy_extDict, + ZSTD_compressBlock_lazy2_extDict, ZSTD_compressBlock_btlazy2_extDict, ZSTD_compressBlock_btopt_extDict, ZSTD_compressBlock_btopt2_extDict}}; + + return blockCompressor[extDict][(U32)strat]; +} + +static size_t ZSTD_compressBlock_internal(ZSTD_CCtx *zc, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor(zc->params.cParams.strategy, zc->lowLimit < zc->dictLimit); + const BYTE *const base = zc->base; + const BYTE *const istart = (const BYTE *)src; + const U32 curr = (U32)(istart - base); + if (srcSize < MIN_CBLOCK_SIZE + ZSTD_blockHeaderSize + 1) + return 0; /* don't even attempt compression below a certain srcSize */ + ZSTD_resetSeqStore(&(zc->seqStore)); + if (curr > zc->nextToUpdate + 384) + zc->nextToUpdate = curr - MIN(192, (U32)(curr - zc->nextToUpdate - 384)); /* update tree not updated after finding very long rep matches */ + blockCompressor(zc, src, srcSize); + return ZSTD_compressSequences(zc, dst, dstCapacity, srcSize); +} + +/*! ZSTD_compress_generic() : +* Compress a chunk of data into one or multiple blocks. +* All blocks will be terminated, all input will be consumed. +* Function will issue an error if there is not enough `dstCapacity` to hold the compressed content. +* Frame is supposed already started (header already produced) +* @return : compressed size, or an error code +*/ +static size_t ZSTD_compress_generic(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, U32 lastFrameChunk) +{ + size_t blockSize = cctx->blockSize; + size_t remaining = srcSize; + const BYTE *ip = (const BYTE *)src; + BYTE *const ostart = (BYTE *)dst; + BYTE *op = ostart; + U32 const maxDist = 1 << cctx->params.cParams.windowLog; + + if (cctx->params.fParams.checksumFlag && srcSize) + xxh64_update(&cctx->xxhState, src, srcSize); + + while (remaining) { + U32 const lastBlock = lastFrameChunk & (blockSize >= remaining); + size_t cSize; + + if (dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE) + return ERROR(dstSize_tooSmall); /* not enough space to store compressed block */ + if (remaining < blockSize) + blockSize = remaining; + + /* preemptive overflow correction */ + if (cctx->lowLimit > (3U << 29)) { + U32 const cycleMask = (1 << ZSTD_cycleLog(cctx->params.cParams.hashLog, cctx->params.cParams.strategy)) - 1; + U32 const curr = (U32)(ip - cctx->base); + U32 const newCurr = (curr & cycleMask) + (1 << cctx->params.cParams.windowLog); + U32 const correction = curr - newCurr; + ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX_64 <= 30); + ZSTD_reduceIndex(cctx, correction); + cctx->base += correction; + cctx->dictBase += correction; + cctx->lowLimit -= correction; + cctx->dictLimit -= correction; + if (cctx->nextToUpdate < correction) + cctx->nextToUpdate = 0; + else + cctx->nextToUpdate -= correction; + } + + if ((U32)(ip + blockSize - cctx->base) > cctx->loadedDictEnd + maxDist) { + /* enforce maxDist */ + U32 const newLowLimit = (U32)(ip + blockSize - cctx->base) - maxDist; + if (cctx->lowLimit < newLowLimit) + cctx->lowLimit = newLowLimit; + if (cctx->dictLimit < cctx->lowLimit) + cctx->dictLimit = cctx->lowLimit; + } + + cSize = ZSTD_compressBlock_internal(cctx, op + ZSTD_blockHeaderSize, dstCapacity - ZSTD_blockHeaderSize, ip, blockSize); + if (ZSTD_isError(cSize)) + return cSize; + + if (cSize == 0) { /* block is not compressible */ + U32 const cBlockHeader24 = lastBlock + (((U32)bt_raw) << 1) + (U32)(blockSize << 3); + if (blockSize + ZSTD_blockHeaderSize > dstCapacity) + return ERROR(dstSize_tooSmall); + ZSTD_writeLE32(op, cBlockHeader24); /* no pb, 4th byte will be overwritten */ + memcpy(op + ZSTD_blockHeaderSize, ip, blockSize); + cSize = ZSTD_blockHeaderSize + blockSize; + } else { + U32 const cBlockHeader24 = lastBlock + (((U32)bt_compressed) << 1) + (U32)(cSize << 3); + ZSTD_writeLE24(op, cBlockHeader24); + cSize += ZSTD_blockHeaderSize; + } + + remaining -= blockSize; + dstCapacity -= cSize; + ip += blockSize; + op += cSize; + } + + if (lastFrameChunk && (op > ostart)) + cctx->stage = ZSTDcs_ending; + return op - ostart; +} + +static size_t ZSTD_writeFrameHeader(void *dst, size_t dstCapacity, ZSTD_parameters params, U64 pledgedSrcSize, U32 dictID) +{ + BYTE *const op = (BYTE *)dst; + U32 const dictIDSizeCode = (dictID > 0) + (dictID >= 256) + (dictID >= 65536); /* 0-3 */ + U32 const checksumFlag = params.fParams.checksumFlag > 0; + U32 const windowSize = 1U << params.cParams.windowLog; + U32 const singleSegment = params.fParams.contentSizeFlag && (windowSize >= pledgedSrcSize); + BYTE const windowLogByte = (BYTE)((params.cParams.windowLog - ZSTD_WINDOWLOG_ABSOLUTEMIN) << 3); + U32 const fcsCode = + params.fParams.contentSizeFlag ? (pledgedSrcSize >= 256) + (pledgedSrcSize >= 65536 + 256) + (pledgedSrcSize >= 0xFFFFFFFFU) : 0; /* 0-3 */ + BYTE const frameHeaderDecriptionByte = (BYTE)(dictIDSizeCode + (checksumFlag << 2) + (singleSegment << 5) + (fcsCode << 6)); + size_t pos; + + if (dstCapacity < ZSTD_frameHeaderSize_max) + return ERROR(dstSize_tooSmall); + + ZSTD_writeLE32(dst, ZSTD_MAGICNUMBER); + op[4] = frameHeaderDecriptionByte; + pos = 5; + if (!singleSegment) + op[pos++] = windowLogByte; + switch (dictIDSizeCode) { + default: /* impossible */ + case 0: break; + case 1: + op[pos] = (BYTE)(dictID); + pos++; + break; + case 2: + ZSTD_writeLE16(op + pos, (U16)dictID); + pos += 2; + break; + case 3: + ZSTD_writeLE32(op + pos, dictID); + pos += 4; + break; + } + switch (fcsCode) { + default: /* impossible */ + case 0: + if (singleSegment) + op[pos++] = (BYTE)(pledgedSrcSize); + break; + case 1: + ZSTD_writeLE16(op + pos, (U16)(pledgedSrcSize - 256)); + pos += 2; + break; + case 2: + ZSTD_writeLE32(op + pos, (U32)(pledgedSrcSize)); + pos += 4; + break; + case 3: + ZSTD_writeLE64(op + pos, (U64)(pledgedSrcSize)); + pos += 8; + break; + } + return pos; +} + +static size_t ZSTD_compressContinue_internal(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, U32 frame, U32 lastFrameChunk) +{ + const BYTE *const ip = (const BYTE *)src; + size_t fhSize = 0; + + if (cctx->stage == ZSTDcs_created) + return ERROR(stage_wrong); /* missing init (ZSTD_compressBegin) */ + + if (frame && (cctx->stage == ZSTDcs_init)) { + fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, cctx->params, cctx->frameContentSize, cctx->dictID); + if (ZSTD_isError(fhSize)) + return fhSize; + dstCapacity -= fhSize; + dst = (char *)dst + fhSize; + cctx->stage = ZSTDcs_ongoing; + } + + /* Check if blocks follow each other */ + if (src != cctx->nextSrc) { + /* not contiguous */ + ptrdiff_t const delta = cctx->nextSrc - ip; + cctx->lowLimit = cctx->dictLimit; + cctx->dictLimit = (U32)(cctx->nextSrc - cctx->base); + cctx->dictBase = cctx->base; + cctx->base -= delta; + cctx->nextToUpdate = cctx->dictLimit; + if (cctx->dictLimit - cctx->lowLimit < HASH_READ_SIZE) + cctx->lowLimit = cctx->dictLimit; /* too small extDict */ + } + + /* if input and dictionary overlap : reduce dictionary (area presumed modified by input) */ + if ((ip + srcSize > cctx->dictBase + cctx->lowLimit) & (ip < cctx->dictBase + cctx->dictLimit)) { + ptrdiff_t const highInputIdx = (ip + srcSize) - cctx->dictBase; + U32 const lowLimitMax = (highInputIdx > (ptrdiff_t)cctx->dictLimit) ? cctx->dictLimit : (U32)highInputIdx; + cctx->lowLimit = lowLimitMax; + } + + cctx->nextSrc = ip + srcSize; + + if (srcSize) { + size_t const cSize = frame ? ZSTD_compress_generic(cctx, dst, dstCapacity, src, srcSize, lastFrameChunk) + : ZSTD_compressBlock_internal(cctx, dst, dstCapacity, src, srcSize); + if (ZSTD_isError(cSize)) + return cSize; + return cSize + fhSize; + } else + return fhSize; +} + +size_t ZSTD_compressContinue(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 1, 0); +} + +size_t ZSTD_getBlockSizeMax(ZSTD_CCtx *cctx) { return MIN(ZSTD_BLOCKSIZE_ABSOLUTEMAX, 1 << cctx->params.cParams.windowLog); } + +size_t ZSTD_compressBlock(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t const blockSizeMax = ZSTD_getBlockSizeMax(cctx); + if (srcSize > blockSizeMax) + return ERROR(srcSize_wrong); + return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 0, 0); +} + +/*! ZSTD_loadDictionaryContent() : + * @return : 0, or an error code + */ +static size_t ZSTD_loadDictionaryContent(ZSTD_CCtx *zc, const void *src, size_t srcSize) +{ + const BYTE *const ip = (const BYTE *)src; + const BYTE *const iend = ip + srcSize; + + /* input becomes curr prefix */ + zc->lowLimit = zc->dictLimit; + zc->dictLimit = (U32)(zc->nextSrc - zc->base); + zc->dictBase = zc->base; + zc->base += ip - zc->nextSrc; + zc->nextToUpdate = zc->dictLimit; + zc->loadedDictEnd = zc->forceWindow ? 0 : (U32)(iend - zc->base); + + zc->nextSrc = iend; + if (srcSize <= HASH_READ_SIZE) + return 0; + + switch (zc->params.cParams.strategy) { + case ZSTD_fast: ZSTD_fillHashTable(zc, iend, zc->params.cParams.searchLength); break; + + case ZSTD_dfast: ZSTD_fillDoubleHashTable(zc, iend, zc->params.cParams.searchLength); break; + + case ZSTD_greedy: + case ZSTD_lazy: + case ZSTD_lazy2: + if (srcSize >= HASH_READ_SIZE) + ZSTD_insertAndFindFirstIndex(zc, iend - HASH_READ_SIZE, zc->params.cParams.searchLength); + break; + + case ZSTD_btlazy2: + case ZSTD_btopt: + case ZSTD_btopt2: + if (srcSize >= HASH_READ_SIZE) + ZSTD_updateTree(zc, iend - HASH_READ_SIZE, iend, 1 << zc->params.cParams.searchLog, zc->params.cParams.searchLength); + break; + + default: + return ERROR(GENERIC); /* strategy doesn't exist; impossible */ + } + + zc->nextToUpdate = (U32)(iend - zc->base); + return 0; +} + +/* Dictionaries that assign zero probability to symbols that show up causes problems + when FSE encoding. Refuse dictionaries that assign zero probability to symbols + that we may encounter during compression. + NOTE: This behavior is not standard and could be improved in the future. */ +static size_t ZSTD_checkDictNCount(short *normalizedCounter, unsigned dictMaxSymbolValue, unsigned maxSymbolValue) +{ + U32 s; + if (dictMaxSymbolValue < maxSymbolValue) + return ERROR(dictionary_corrupted); + for (s = 0; s <= maxSymbolValue; ++s) { + if (normalizedCounter[s] == 0) + return ERROR(dictionary_corrupted); + } + return 0; +} + +/* Dictionary format : + * See : + * https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format + */ +/*! ZSTD_loadZstdDictionary() : + * @return : 0, or an error code + * assumptions : magic number supposed already checked + * dictSize supposed > 8 + */ +static size_t ZSTD_loadZstdDictionary(ZSTD_CCtx *cctx, const void *dict, size_t dictSize) +{ + const BYTE *dictPtr = (const BYTE *)dict; + const BYTE *const dictEnd = dictPtr + dictSize; + short offcodeNCount[MaxOff + 1]; + unsigned offcodeMaxValue = MaxOff; + + dictPtr += 4; /* skip magic number */ + cctx->dictID = cctx->params.fParams.noDictIDFlag ? 0 : ZSTD_readLE32(dictPtr); + dictPtr += 4; + + { + size_t const hufHeaderSize = HUF_readCTable_wksp(cctx->hufTable, 255, dictPtr, dictEnd - dictPtr, cctx->tmpCounters, sizeof(cctx->tmpCounters)); + if (HUF_isError(hufHeaderSize)) + return ERROR(dictionary_corrupted); + dictPtr += hufHeaderSize; + } + + { + unsigned offcodeLog; + size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(offcodeHeaderSize)) + return ERROR(dictionary_corrupted); + if (offcodeLog > OffFSELog) + return ERROR(dictionary_corrupted); + /* Defer checking offcodeMaxValue because we need to know the size of the dictionary content */ + CHECK_E(FSE_buildCTable_wksp(cctx->offcodeCTable, offcodeNCount, offcodeMaxValue, offcodeLog, cctx->tmpCounters, sizeof(cctx->tmpCounters)), + dictionary_corrupted); + dictPtr += offcodeHeaderSize; + } + + { + short matchlengthNCount[MaxML + 1]; + unsigned matchlengthMaxValue = MaxML, matchlengthLog; + size_t const matchlengthHeaderSize = FSE_readNCount(matchlengthNCount, &matchlengthMaxValue, &matchlengthLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(matchlengthHeaderSize)) + return ERROR(dictionary_corrupted); + if (matchlengthLog > MLFSELog) + return ERROR(dictionary_corrupted); + /* Every match length code must have non-zero probability */ + CHECK_F(ZSTD_checkDictNCount(matchlengthNCount, matchlengthMaxValue, MaxML)); + CHECK_E( + FSE_buildCTable_wksp(cctx->matchlengthCTable, matchlengthNCount, matchlengthMaxValue, matchlengthLog, cctx->tmpCounters, sizeof(cctx->tmpCounters)), + dictionary_corrupted); + dictPtr += matchlengthHeaderSize; + } + + { + short litlengthNCount[MaxLL + 1]; + unsigned litlengthMaxValue = MaxLL, litlengthLog; + size_t const litlengthHeaderSize = FSE_readNCount(litlengthNCount, &litlengthMaxValue, &litlengthLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(litlengthHeaderSize)) + return ERROR(dictionary_corrupted); + if (litlengthLog > LLFSELog) + return ERROR(dictionary_corrupted); + /* Every literal length code must have non-zero probability */ + CHECK_F(ZSTD_checkDictNCount(litlengthNCount, litlengthMaxValue, MaxLL)); + CHECK_E(FSE_buildCTable_wksp(cctx->litlengthCTable, litlengthNCount, litlengthMaxValue, litlengthLog, cctx->tmpCounters, sizeof(cctx->tmpCounters)), + dictionary_corrupted); + dictPtr += litlengthHeaderSize; + } + + if (dictPtr + 12 > dictEnd) + return ERROR(dictionary_corrupted); + cctx->rep[0] = ZSTD_readLE32(dictPtr + 0); + cctx->rep[1] = ZSTD_readLE32(dictPtr + 4); + cctx->rep[2] = ZSTD_readLE32(dictPtr + 8); + dictPtr += 12; + + { + size_t const dictContentSize = (size_t)(dictEnd - dictPtr); + U32 offcodeMax = MaxOff; + if (dictContentSize <= ((U32)-1) - 128 KB) { + U32 const maxOffset = (U32)dictContentSize + 128 KB; /* The maximum offset that must be supported */ + offcodeMax = ZSTD_highbit32(maxOffset); /* Calculate minimum offset code required to represent maxOffset */ + } + /* All offset values <= dictContentSize + 128 KB must be representable */ + CHECK_F(ZSTD_checkDictNCount(offcodeNCount, offcodeMaxValue, MIN(offcodeMax, MaxOff))); + /* All repCodes must be <= dictContentSize and != 0*/ + { + U32 u; + for (u = 0; u < 3; u++) { + if (cctx->rep[u] == 0) + return ERROR(dictionary_corrupted); + if (cctx->rep[u] > dictContentSize) + return ERROR(dictionary_corrupted); + } + } + + cctx->flagStaticTables = 1; + cctx->flagStaticHufTable = HUF_repeat_valid; + return ZSTD_loadDictionaryContent(cctx, dictPtr, dictContentSize); + } +} + +/** ZSTD_compress_insertDictionary() : +* @return : 0, or an error code */ +static size_t ZSTD_compress_insertDictionary(ZSTD_CCtx *cctx, const void *dict, size_t dictSize) +{ + if ((dict == NULL) || (dictSize <= 8)) + return 0; + + /* dict as pure content */ + if ((ZSTD_readLE32(dict) != ZSTD_DICT_MAGIC) || (cctx->forceRawDict)) + return ZSTD_loadDictionaryContent(cctx, dict, dictSize); + + /* dict as zstd dictionary */ + return ZSTD_loadZstdDictionary(cctx, dict, dictSize); +} + +/*! ZSTD_compressBegin_internal() : +* @return : 0, or an error code */ +static size_t ZSTD_compressBegin_internal(ZSTD_CCtx *cctx, const void *dict, size_t dictSize, ZSTD_parameters params, U64 pledgedSrcSize) +{ + ZSTD_compResetPolicy_e const crp = dictSize ? ZSTDcrp_fullReset : ZSTDcrp_continue; + CHECK_F(ZSTD_resetCCtx_advanced(cctx, params, pledgedSrcSize, crp)); + return ZSTD_compress_insertDictionary(cctx, dict, dictSize); +} + +/*! ZSTD_compressBegin_advanced() : +* @return : 0, or an error code */ +size_t ZSTD_compressBegin_advanced(ZSTD_CCtx *cctx, const void *dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize) +{ + /* compression parameters verification and optimization */ + CHECK_F(ZSTD_checkCParams(params.cParams)); + return ZSTD_compressBegin_internal(cctx, dict, dictSize, params, pledgedSrcSize); +} + +size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx *cctx, const void *dict, size_t dictSize, int compressionLevel) +{ + ZSTD_parameters const params = ZSTD_getParams(compressionLevel, 0, dictSize); + return ZSTD_compressBegin_internal(cctx, dict, dictSize, params, 0); +} + +size_t ZSTD_compressBegin(ZSTD_CCtx *cctx, int compressionLevel) { return ZSTD_compressBegin_usingDict(cctx, NULL, 0, compressionLevel); } + +/*! ZSTD_writeEpilogue() : +* Ends a frame. +* @return : nb of bytes written into dst (or an error code) */ +static size_t ZSTD_writeEpilogue(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity) +{ + BYTE *const ostart = (BYTE *)dst; + BYTE *op = ostart; + size_t fhSize = 0; + + if (cctx->stage == ZSTDcs_created) + return ERROR(stage_wrong); /* init missing */ + + /* special case : empty frame */ + if (cctx->stage == ZSTDcs_init) { + fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, cctx->params, 0, 0); + if (ZSTD_isError(fhSize)) + return fhSize; + dstCapacity -= fhSize; + op += fhSize; + cctx->stage = ZSTDcs_ongoing; + } + + if (cctx->stage != ZSTDcs_ending) { + /* write one last empty block, make it the "last" block */ + U32 const cBlockHeader24 = 1 /* last block */ + (((U32)bt_raw) << 1) + 0; + if (dstCapacity < 4) + return ERROR(dstSize_tooSmall); + ZSTD_writeLE32(op, cBlockHeader24); + op += ZSTD_blockHeaderSize; + dstCapacity -= ZSTD_blockHeaderSize; + } + + if (cctx->params.fParams.checksumFlag) { + U32 const checksum = (U32)xxh64_digest(&cctx->xxhState); + if (dstCapacity < 4) + return ERROR(dstSize_tooSmall); + ZSTD_writeLE32(op, checksum); + op += 4; + } + + cctx->stage = ZSTDcs_created; /* return to "created but no init" status */ + return op - ostart; +} + +size_t ZSTD_compressEnd(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t endResult; + size_t const cSize = ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 1, 1); + if (ZSTD_isError(cSize)) + return cSize; + endResult = ZSTD_writeEpilogue(cctx, (char *)dst + cSize, dstCapacity - cSize); + if (ZSTD_isError(endResult)) + return endResult; + return cSize + endResult; +} + +static size_t ZSTD_compress_internal(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const void *dict, size_t dictSize, + ZSTD_parameters params) +{ + CHECK_F(ZSTD_compressBegin_internal(cctx, dict, dictSize, params, srcSize)); + return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); +} + +size_t ZSTD_compress_usingDict(ZSTD_CCtx *ctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const void *dict, size_t dictSize, + ZSTD_parameters params) +{ + return ZSTD_compress_internal(ctx, dst, dstCapacity, src, srcSize, dict, dictSize, params); +} + +size_t ZSTD_compressCCtx(ZSTD_CCtx *ctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, ZSTD_parameters params) +{ + return ZSTD_compress_internal(ctx, dst, dstCapacity, src, srcSize, NULL, 0, params); +} + +/* ===== Dictionary API ===== */ + +struct ZSTD_CDict_s { + void *dictBuffer; + const void *dictContent; + size_t dictContentSize; + ZSTD_CCtx *refContext; +}; /* typedef'd tp ZSTD_CDict within "zstd.h" */ + +size_t ZSTD_CDictWorkspaceBound(ZSTD_compressionParameters cParams) { return ZSTD_CCtxWorkspaceBound(cParams) + ZSTD_ALIGN(sizeof(ZSTD_CDict)); } + +static ZSTD_CDict *ZSTD_createCDict_advanced(const void *dictBuffer, size_t dictSize, unsigned byReference, ZSTD_parameters params, ZSTD_customMem customMem) +{ + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + + { + ZSTD_CDict *const cdict = (ZSTD_CDict *)ZSTD_malloc(sizeof(ZSTD_CDict), customMem); + ZSTD_CCtx *const cctx = ZSTD_createCCtx_advanced(customMem); + + if (!cdict || !cctx) { + ZSTD_free(cdict, customMem); + ZSTD_freeCCtx(cctx); + return NULL; + } + + if ((byReference) || (!dictBuffer) || (!dictSize)) { + cdict->dictBuffer = NULL; + cdict->dictContent = dictBuffer; + } else { + void *const internalBuffer = ZSTD_malloc(dictSize, customMem); + if (!internalBuffer) { + ZSTD_free(cctx, customMem); + ZSTD_free(cdict, customMem); + return NULL; + } + memcpy(internalBuffer, dictBuffer, dictSize); + cdict->dictBuffer = internalBuffer; + cdict->dictContent = internalBuffer; + } + + { + size_t const errorCode = ZSTD_compressBegin_advanced(cctx, cdict->dictContent, dictSize, params, 0); + if (ZSTD_isError(errorCode)) { + ZSTD_free(cdict->dictBuffer, customMem); + ZSTD_free(cdict, customMem); + ZSTD_freeCCtx(cctx); + return NULL; + } + } + + cdict->refContext = cctx; + cdict->dictContentSize = dictSize; + return cdict; + } +} + +ZSTD_CDict *ZSTD_initCDict(const void *dict, size_t dictSize, ZSTD_parameters params, void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + return ZSTD_createCDict_advanced(dict, dictSize, 1, params, stackMem); +} + +size_t ZSTD_freeCDict(ZSTD_CDict *cdict) +{ + if (cdict == NULL) + return 0; /* support free on NULL */ + { + ZSTD_customMem const cMem = cdict->refContext->customMem; + ZSTD_freeCCtx(cdict->refContext); + ZSTD_free(cdict->dictBuffer, cMem); + ZSTD_free(cdict, cMem); + return 0; + } +} + +static ZSTD_parameters ZSTD_getParamsFromCDict(const ZSTD_CDict *cdict) { return ZSTD_getParamsFromCCtx(cdict->refContext); } + +size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx *cctx, const ZSTD_CDict *cdict, unsigned long long pledgedSrcSize) +{ + if (cdict->dictContentSize) + CHECK_F(ZSTD_copyCCtx(cctx, cdict->refContext, pledgedSrcSize)) + else { + ZSTD_parameters params = cdict->refContext->params; + params.fParams.contentSizeFlag = (pledgedSrcSize > 0); + CHECK_F(ZSTD_compressBegin_advanced(cctx, NULL, 0, params, pledgedSrcSize)); + } + return 0; +} + +/*! ZSTD_compress_usingCDict() : +* Compression using a digested Dictionary. +* Faster startup than ZSTD_compress_usingDict(), recommended when same dictionary is used multiple times. +* Note that compression level is decided during dictionary creation */ +size_t ZSTD_compress_usingCDict(ZSTD_CCtx *cctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const ZSTD_CDict *cdict) +{ + CHECK_F(ZSTD_compressBegin_usingCDict(cctx, cdict, srcSize)); + + if (cdict->refContext->params.fParams.contentSizeFlag == 1) { + cctx->params.fParams.contentSizeFlag = 1; + cctx->frameContentSize = srcSize; + } else { + cctx->params.fParams.contentSizeFlag = 0; + } + + return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); +} + +/* ****************************************************************** +* Streaming +********************************************************************/ + +typedef enum { zcss_init, zcss_load, zcss_flush, zcss_final } ZSTD_cStreamStage; + +struct ZSTD_CStream_s { + ZSTD_CCtx *cctx; + ZSTD_CDict *cdictLocal; + const ZSTD_CDict *cdict; + char *inBuff; + size_t inBuffSize; + size_t inToCompress; + size_t inBuffPos; + size_t inBuffTarget; + size_t blockSize; + char *outBuff; + size_t outBuffSize; + size_t outBuffContentSize; + size_t outBuffFlushedSize; + ZSTD_cStreamStage stage; + U32 checksum; + U32 frameEnded; + U64 pledgedSrcSize; + U64 inputProcessed; + ZSTD_parameters params; + ZSTD_customMem customMem; +}; /* typedef'd to ZSTD_CStream within "zstd.h" */ + +size_t ZSTD_CStreamWorkspaceBound(ZSTD_compressionParameters cParams) +{ + size_t const inBuffSize = (size_t)1 << cParams.windowLog; + size_t const blockSize = MIN(ZSTD_BLOCKSIZE_ABSOLUTEMAX, inBuffSize); + size_t const outBuffSize = ZSTD_compressBound(blockSize) + 1; + + return ZSTD_CCtxWorkspaceBound(cParams) + ZSTD_ALIGN(sizeof(ZSTD_CStream)) + ZSTD_ALIGN(inBuffSize) + ZSTD_ALIGN(outBuffSize); +} + +ZSTD_CStream *ZSTD_createCStream_advanced(ZSTD_customMem customMem) +{ + ZSTD_CStream *zcs; + + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + + zcs = (ZSTD_CStream *)ZSTD_malloc(sizeof(ZSTD_CStream), customMem); + if (zcs == NULL) + return NULL; + memset(zcs, 0, sizeof(ZSTD_CStream)); + memcpy(&zcs->customMem, &customMem, sizeof(ZSTD_customMem)); + zcs->cctx = ZSTD_createCCtx_advanced(customMem); + if (zcs->cctx == NULL) { + ZSTD_freeCStream(zcs); + return NULL; + } + return zcs; +} + +size_t ZSTD_freeCStream(ZSTD_CStream *zcs) +{ + if (zcs == NULL) + return 0; /* support free on NULL */ + { + ZSTD_customMem const cMem = zcs->customMem; + ZSTD_freeCCtx(zcs->cctx); + zcs->cctx = NULL; + ZSTD_freeCDict(zcs->cdictLocal); + zcs->cdictLocal = NULL; + ZSTD_free(zcs->inBuff, cMem); + zcs->inBuff = NULL; + ZSTD_free(zcs->outBuff, cMem); + zcs->outBuff = NULL; + ZSTD_free(zcs, cMem); + return 0; + } +} + +/*====== Initialization ======*/ + +size_t ZSTD_CStreamInSize(void) { return ZSTD_BLOCKSIZE_ABSOLUTEMAX; } +size_t ZSTD_CStreamOutSize(void) { return ZSTD_compressBound(ZSTD_BLOCKSIZE_ABSOLUTEMAX) + ZSTD_blockHeaderSize + 4 /* 32-bits hash */; } + +static size_t ZSTD_resetCStream_internal(ZSTD_CStream *zcs, unsigned long long pledgedSrcSize) +{ + if (zcs->inBuffSize == 0) + return ERROR(stage_wrong); /* zcs has not been init at least once => can't reset */ + + if (zcs->cdict) + CHECK_F(ZSTD_compressBegin_usingCDict(zcs->cctx, zcs->cdict, pledgedSrcSize)) + else + CHECK_F(ZSTD_compressBegin_advanced(zcs->cctx, NULL, 0, zcs->params, pledgedSrcSize)); + + zcs->inToCompress = 0; + zcs->inBuffPos = 0; + zcs->inBuffTarget = zcs->blockSize; + zcs->outBuffContentSize = zcs->outBuffFlushedSize = 0; + zcs->stage = zcss_load; + zcs->frameEnded = 0; + zcs->pledgedSrcSize = pledgedSrcSize; + zcs->inputProcessed = 0; + return 0; /* ready to go */ +} + +size_t ZSTD_resetCStream(ZSTD_CStream *zcs, unsigned long long pledgedSrcSize) +{ + + zcs->params.fParams.contentSizeFlag = (pledgedSrcSize > 0); + + return ZSTD_resetCStream_internal(zcs, pledgedSrcSize); +} + +static size_t ZSTD_initCStream_advanced(ZSTD_CStream *zcs, const void *dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize) +{ + /* allocate buffers */ + { + size_t const neededInBuffSize = (size_t)1 << params.cParams.windowLog; + if (zcs->inBuffSize < neededInBuffSize) { + zcs->inBuffSize = neededInBuffSize; + ZSTD_free(zcs->inBuff, zcs->customMem); + zcs->inBuff = (char *)ZSTD_malloc(neededInBuffSize, zcs->customMem); + if (zcs->inBuff == NULL) + return ERROR(memory_allocation); + } + zcs->blockSize = MIN(ZSTD_BLOCKSIZE_ABSOLUTEMAX, neededInBuffSize); + } + if (zcs->outBuffSize < ZSTD_compressBound(zcs->blockSize) + 1) { + zcs->outBuffSize = ZSTD_compressBound(zcs->blockSize) + 1; + ZSTD_free(zcs->outBuff, zcs->customMem); + zcs->outBuff = (char *)ZSTD_malloc(zcs->outBuffSize, zcs->customMem); + if (zcs->outBuff == NULL) + return ERROR(memory_allocation); + } + + if (dict && dictSize >= 8) { + ZSTD_freeCDict(zcs->cdictLocal); + zcs->cdictLocal = ZSTD_createCDict_advanced(dict, dictSize, 0, params, zcs->customMem); + if (zcs->cdictLocal == NULL) + return ERROR(memory_allocation); + zcs->cdict = zcs->cdictLocal; + } else + zcs->cdict = NULL; + + zcs->checksum = params.fParams.checksumFlag > 0; + zcs->params = params; + + return ZSTD_resetCStream_internal(zcs, pledgedSrcSize); +} + +ZSTD_CStream *ZSTD_initCStream(ZSTD_parameters params, unsigned long long pledgedSrcSize, void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + ZSTD_CStream *const zcs = ZSTD_createCStream_advanced(stackMem); + if (zcs) { + size_t const code = ZSTD_initCStream_advanced(zcs, NULL, 0, params, pledgedSrcSize); + if (ZSTD_isError(code)) { + return NULL; + } + } + return zcs; +} + +ZSTD_CStream *ZSTD_initCStream_usingCDict(const ZSTD_CDict *cdict, unsigned long long pledgedSrcSize, void *workspace, size_t workspaceSize) +{ + ZSTD_parameters const params = ZSTD_getParamsFromCDict(cdict); + ZSTD_CStream *const zcs = ZSTD_initCStream(params, pledgedSrcSize, workspace, workspaceSize); + if (zcs) { + zcs->cdict = cdict; + if (ZSTD_isError(ZSTD_resetCStream_internal(zcs, pledgedSrcSize))) { + return NULL; + } + } + return zcs; +} + +/*====== Compression ======*/ + +typedef enum { zsf_gather, zsf_flush, zsf_end } ZSTD_flush_e; + +ZSTD_STATIC size_t ZSTD_limitCopy(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t const length = MIN(dstCapacity, srcSize); + memcpy(dst, src, length); + return length; +} + +static size_t ZSTD_compressStream_generic(ZSTD_CStream *zcs, void *dst, size_t *dstCapacityPtr, const void *src, size_t *srcSizePtr, ZSTD_flush_e const flush) +{ + U32 someMoreWork = 1; + const char *const istart = (const char *)src; + const char *const iend = istart + *srcSizePtr; + const char *ip = istart; + char *const ostart = (char *)dst; + char *const oend = ostart + *dstCapacityPtr; + char *op = ostart; + + while (someMoreWork) { + switch (zcs->stage) { + case zcss_init: + return ERROR(init_missing); /* call ZBUFF_compressInit() first ! */ + + case zcss_load: + /* complete inBuffer */ + { + size_t const toLoad = zcs->inBuffTarget - zcs->inBuffPos; + size_t const loaded = ZSTD_limitCopy(zcs->inBuff + zcs->inBuffPos, toLoad, ip, iend - ip); + zcs->inBuffPos += loaded; + ip += loaded; + if ((zcs->inBuffPos == zcs->inToCompress) || (!flush && (toLoad != loaded))) { + someMoreWork = 0; + break; /* not enough input to get a full block : stop there, wait for more */ + } + } + /* compress curr block (note : this stage cannot be stopped in the middle) */ + { + void *cDst; + size_t cSize; + size_t const iSize = zcs->inBuffPos - zcs->inToCompress; + size_t oSize = oend - op; + if (oSize >= ZSTD_compressBound(iSize)) + cDst = op; /* compress directly into output buffer (avoid flush stage) */ + else + cDst = zcs->outBuff, oSize = zcs->outBuffSize; + cSize = (flush == zsf_end) ? ZSTD_compressEnd(zcs->cctx, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize) + : ZSTD_compressContinue(zcs->cctx, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize); + if (ZSTD_isError(cSize)) + return cSize; + if (flush == zsf_end) + zcs->frameEnded = 1; + /* prepare next block */ + zcs->inBuffTarget = zcs->inBuffPos + zcs->blockSize; + if (zcs->inBuffTarget > zcs->inBuffSize) + zcs->inBuffPos = 0, zcs->inBuffTarget = zcs->blockSize; /* note : inBuffSize >= blockSize */ + zcs->inToCompress = zcs->inBuffPos; + if (cDst == op) { + op += cSize; + break; + } /* no need to flush */ + zcs->outBuffContentSize = cSize; + zcs->outBuffFlushedSize = 0; + zcs->stage = zcss_flush; /* pass-through to flush stage */ + } + + case zcss_flush: { + size_t const toFlush = zcs->outBuffContentSize - zcs->outBuffFlushedSize; + size_t const flushed = ZSTD_limitCopy(op, oend - op, zcs->outBuff + zcs->outBuffFlushedSize, toFlush); + op += flushed; + zcs->outBuffFlushedSize += flushed; + if (toFlush != flushed) { + someMoreWork = 0; + break; + } /* dst too small to store flushed data : stop there */ + zcs->outBuffContentSize = zcs->outBuffFlushedSize = 0; + zcs->stage = zcss_load; + break; + } + + case zcss_final: + someMoreWork = 0; /* do nothing */ + break; + + default: + return ERROR(GENERIC); /* impossible */ + } + } + + *srcSizePtr = ip - istart; + *dstCapacityPtr = op - ostart; + zcs->inputProcessed += *srcSizePtr; + if (zcs->frameEnded) + return 0; + { + size_t hintInSize = zcs->inBuffTarget - zcs->inBuffPos; + if (hintInSize == 0) + hintInSize = zcs->blockSize; + return hintInSize; + } +} + +size_t ZSTD_compressStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output, ZSTD_inBuffer *input) +{ + size_t sizeRead = input->size - input->pos; + size_t sizeWritten = output->size - output->pos; + size_t const result = + ZSTD_compressStream_generic(zcs, (char *)(output->dst) + output->pos, &sizeWritten, (const char *)(input->src) + input->pos, &sizeRead, zsf_gather); + input->pos += sizeRead; + output->pos += sizeWritten; + return result; +} + +/*====== Finalize ======*/ + +/*! ZSTD_flushStream() : +* @return : amount of data remaining to flush */ +size_t ZSTD_flushStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output) +{ + size_t srcSize = 0; + size_t sizeWritten = output->size - output->pos; + size_t const result = ZSTD_compressStream_generic(zcs, (char *)(output->dst) + output->pos, &sizeWritten, &srcSize, + &srcSize, /* use a valid src address instead of NULL */ + zsf_flush); + output->pos += sizeWritten; + if (ZSTD_isError(result)) + return result; + return zcs->outBuffContentSize - zcs->outBuffFlushedSize; /* remaining to flush */ +} + +size_t ZSTD_endStream(ZSTD_CStream *zcs, ZSTD_outBuffer *output) +{ + BYTE *const ostart = (BYTE *)(output->dst) + output->pos; + BYTE *const oend = (BYTE *)(output->dst) + output->size; + BYTE *op = ostart; + + if ((zcs->pledgedSrcSize) && (zcs->inputProcessed != zcs->pledgedSrcSize)) + return ERROR(srcSize_wrong); /* pledgedSrcSize not respected */ + + if (zcs->stage != zcss_final) { + /* flush whatever remains */ + size_t srcSize = 0; + size_t sizeWritten = output->size - output->pos; + size_t const notEnded = + ZSTD_compressStream_generic(zcs, ostart, &sizeWritten, &srcSize, &srcSize, zsf_end); /* use a valid src address instead of NULL */ + size_t const remainingToFlush = zcs->outBuffContentSize - zcs->outBuffFlushedSize; + op += sizeWritten; + if (remainingToFlush) { + output->pos += sizeWritten; + return remainingToFlush + ZSTD_BLOCKHEADERSIZE /* final empty block */ + (zcs->checksum * 4); + } + /* create epilogue */ + zcs->stage = zcss_final; + zcs->outBuffContentSize = !notEnded ? 0 : ZSTD_compressEnd(zcs->cctx, zcs->outBuff, zcs->outBuffSize, NULL, + 0); /* write epilogue, including final empty block, into outBuff */ + } + + /* flush epilogue */ + { + size_t const toFlush = zcs->outBuffContentSize - zcs->outBuffFlushedSize; + size_t const flushed = ZSTD_limitCopy(op, oend - op, zcs->outBuff + zcs->outBuffFlushedSize, toFlush); + op += flushed; + zcs->outBuffFlushedSize += flushed; + output->pos += op - ostart; + if (toFlush == flushed) + zcs->stage = zcss_init; /* end reached */ + return toFlush - flushed; + } +} + +/*-===== Pre-defined compression levels =====-*/ + +#define ZSTD_DEFAULT_CLEVEL 1 +#define ZSTD_MAX_CLEVEL 22 +int ZSTD_maxCLevel(void) { return ZSTD_MAX_CLEVEL; } + +static const ZSTD_compressionParameters ZSTD_defaultCParameters[4][ZSTD_MAX_CLEVEL + 1] = { + { + /* "default" */ + /* W, C, H, S, L, TL, strat */ + {18, 12, 12, 1, 7, 16, ZSTD_fast}, /* level 0 - never used */ + {19, 13, 14, 1, 7, 16, ZSTD_fast}, /* level 1 */ + {19, 15, 16, 1, 6, 16, ZSTD_fast}, /* level 2 */ + {20, 16, 17, 1, 5, 16, ZSTD_dfast}, /* level 3.*/ + {20, 18, 18, 1, 5, 16, ZSTD_dfast}, /* level 4.*/ + {20, 15, 18, 3, 5, 16, ZSTD_greedy}, /* level 5 */ + {21, 16, 19, 2, 5, 16, ZSTD_lazy}, /* level 6 */ + {21, 17, 20, 3, 5, 16, ZSTD_lazy}, /* level 7 */ + {21, 18, 20, 3, 5, 16, ZSTD_lazy2}, /* level 8 */ + {21, 20, 20, 3, 5, 16, ZSTD_lazy2}, /* level 9 */ + {21, 19, 21, 4, 5, 16, ZSTD_lazy2}, /* level 10 */ + {22, 20, 22, 4, 5, 16, ZSTD_lazy2}, /* level 11 */ + {22, 20, 22, 5, 5, 16, ZSTD_lazy2}, /* level 12 */ + {22, 21, 22, 5, 5, 16, ZSTD_lazy2}, /* level 13 */ + {22, 21, 22, 6, 5, 16, ZSTD_lazy2}, /* level 14 */ + {22, 21, 21, 5, 5, 16, ZSTD_btlazy2}, /* level 15 */ + {23, 22, 22, 5, 5, 16, ZSTD_btlazy2}, /* level 16 */ + {23, 21, 22, 4, 5, 24, ZSTD_btopt}, /* level 17 */ + {23, 23, 22, 6, 5, 32, ZSTD_btopt}, /* level 18 */ + {23, 23, 22, 6, 3, 48, ZSTD_btopt}, /* level 19 */ + {25, 25, 23, 7, 3, 64, ZSTD_btopt2}, /* level 20 */ + {26, 26, 23, 7, 3, 256, ZSTD_btopt2}, /* level 21 */ + {27, 27, 25, 9, 3, 512, ZSTD_btopt2}, /* level 22 */ + }, + { + /* for srcSize <= 256 KB */ + /* W, C, H, S, L, T, strat */ + {0, 0, 0, 0, 0, 0, ZSTD_fast}, /* level 0 - not used */ + {18, 13, 14, 1, 6, 8, ZSTD_fast}, /* level 1 */ + {18, 14, 13, 1, 5, 8, ZSTD_dfast}, /* level 2 */ + {18, 16, 15, 1, 5, 8, ZSTD_dfast}, /* level 3 */ + {18, 15, 17, 1, 5, 8, ZSTD_greedy}, /* level 4.*/ + {18, 16, 17, 4, 5, 8, ZSTD_greedy}, /* level 5.*/ + {18, 16, 17, 3, 5, 8, ZSTD_lazy}, /* level 6.*/ + {18, 17, 17, 4, 4, 8, ZSTD_lazy}, /* level 7 */ + {18, 17, 17, 4, 4, 8, ZSTD_lazy2}, /* level 8 */ + {18, 17, 17, 5, 4, 8, ZSTD_lazy2}, /* level 9 */ + {18, 17, 17, 6, 4, 8, ZSTD_lazy2}, /* level 10 */ + {18, 18, 17, 6, 4, 8, ZSTD_lazy2}, /* level 11.*/ + {18, 18, 17, 7, 4, 8, ZSTD_lazy2}, /* level 12.*/ + {18, 19, 17, 6, 4, 8, ZSTD_btlazy2}, /* level 13 */ + {18, 18, 18, 4, 4, 16, ZSTD_btopt}, /* level 14.*/ + {18, 18, 18, 4, 3, 16, ZSTD_btopt}, /* level 15.*/ + {18, 19, 18, 6, 3, 32, ZSTD_btopt}, /* level 16.*/ + {18, 19, 18, 8, 3, 64, ZSTD_btopt}, /* level 17.*/ + {18, 19, 18, 9, 3, 128, ZSTD_btopt}, /* level 18.*/ + {18, 19, 18, 10, 3, 256, ZSTD_btopt}, /* level 19.*/ + {18, 19, 18, 11, 3, 512, ZSTD_btopt2}, /* level 20.*/ + {18, 19, 18, 12, 3, 512, ZSTD_btopt2}, /* level 21.*/ + {18, 19, 18, 13, 3, 512, ZSTD_btopt2}, /* level 22.*/ + }, + { + /* for srcSize <= 128 KB */ + /* W, C, H, S, L, T, strat */ + {17, 12, 12, 1, 7, 8, ZSTD_fast}, /* level 0 - not used */ + {17, 12, 13, 1, 6, 8, ZSTD_fast}, /* level 1 */ + {17, 13, 16, 1, 5, 8, ZSTD_fast}, /* level 2 */ + {17, 16, 16, 2, 5, 8, ZSTD_dfast}, /* level 3 */ + {17, 13, 15, 3, 4, 8, ZSTD_greedy}, /* level 4 */ + {17, 15, 17, 4, 4, 8, ZSTD_greedy}, /* level 5 */ + {17, 16, 17, 3, 4, 8, ZSTD_lazy}, /* level 6 */ + {17, 15, 17, 4, 4, 8, ZSTD_lazy2}, /* level 7 */ + {17, 17, 17, 4, 4, 8, ZSTD_lazy2}, /* level 8 */ + {17, 17, 17, 5, 4, 8, ZSTD_lazy2}, /* level 9 */ + {17, 17, 17, 6, 4, 8, ZSTD_lazy2}, /* level 10 */ + {17, 17, 17, 7, 4, 8, ZSTD_lazy2}, /* level 11 */ + {17, 17, 17, 8, 4, 8, ZSTD_lazy2}, /* level 12 */ + {17, 18, 17, 6, 4, 8, ZSTD_btlazy2}, /* level 13.*/ + {17, 17, 17, 7, 3, 8, ZSTD_btopt}, /* level 14.*/ + {17, 17, 17, 7, 3, 16, ZSTD_btopt}, /* level 15.*/ + {17, 18, 17, 7, 3, 32, ZSTD_btopt}, /* level 16.*/ + {17, 18, 17, 7, 3, 64, ZSTD_btopt}, /* level 17.*/ + {17, 18, 17, 7, 3, 256, ZSTD_btopt}, /* level 18.*/ + {17, 18, 17, 8, 3, 256, ZSTD_btopt}, /* level 19.*/ + {17, 18, 17, 9, 3, 256, ZSTD_btopt2}, /* level 20.*/ + {17, 18, 17, 10, 3, 256, ZSTD_btopt2}, /* level 21.*/ + {17, 18, 17, 11, 3, 512, ZSTD_btopt2}, /* level 22.*/ + }, + { + /* for srcSize <= 16 KB */ + /* W, C, H, S, L, T, strat */ + {14, 12, 12, 1, 7, 6, ZSTD_fast}, /* level 0 - not used */ + {14, 14, 14, 1, 6, 6, ZSTD_fast}, /* level 1 */ + {14, 14, 14, 1, 4, 6, ZSTD_fast}, /* level 2 */ + {14, 14, 14, 1, 4, 6, ZSTD_dfast}, /* level 3.*/ + {14, 14, 14, 4, 4, 6, ZSTD_greedy}, /* level 4.*/ + {14, 14, 14, 3, 4, 6, ZSTD_lazy}, /* level 5.*/ + {14, 14, 14, 4, 4, 6, ZSTD_lazy2}, /* level 6 */ + {14, 14, 14, 5, 4, 6, ZSTD_lazy2}, /* level 7 */ + {14, 14, 14, 6, 4, 6, ZSTD_lazy2}, /* level 8.*/ + {14, 15, 14, 6, 4, 6, ZSTD_btlazy2}, /* level 9.*/ + {14, 15, 14, 3, 3, 6, ZSTD_btopt}, /* level 10.*/ + {14, 15, 14, 6, 3, 8, ZSTD_btopt}, /* level 11.*/ + {14, 15, 14, 6, 3, 16, ZSTD_btopt}, /* level 12.*/ + {14, 15, 14, 6, 3, 24, ZSTD_btopt}, /* level 13.*/ + {14, 15, 15, 6, 3, 48, ZSTD_btopt}, /* level 14.*/ + {14, 15, 15, 6, 3, 64, ZSTD_btopt}, /* level 15.*/ + {14, 15, 15, 6, 3, 96, ZSTD_btopt}, /* level 16.*/ + {14, 15, 15, 6, 3, 128, ZSTD_btopt}, /* level 17.*/ + {14, 15, 15, 6, 3, 256, ZSTD_btopt}, /* level 18.*/ + {14, 15, 15, 7, 3, 256, ZSTD_btopt}, /* level 19.*/ + {14, 15, 15, 8, 3, 256, ZSTD_btopt2}, /* level 20.*/ + {14, 15, 15, 9, 3, 256, ZSTD_btopt2}, /* level 21.*/ + {14, 15, 15, 10, 3, 256, ZSTD_btopt2}, /* level 22.*/ + }, +}; + +/*! ZSTD_getCParams() : +* @return ZSTD_compressionParameters structure for a selected compression level, `srcSize` and `dictSize`. +* Size values are optional, provide 0 if not known or unused */ +ZSTD_compressionParameters ZSTD_getCParams(int compressionLevel, unsigned long long srcSize, size_t dictSize) +{ + ZSTD_compressionParameters cp; + size_t const addedSize = srcSize ? 0 : 500; + U64 const rSize = srcSize + dictSize ? srcSize + dictSize + addedSize : (U64)-1; + U32 const tableID = (rSize <= 256 KB) + (rSize <= 128 KB) + (rSize <= 16 KB); /* intentional underflow for srcSizeHint == 0 */ + if (compressionLevel <= 0) + compressionLevel = ZSTD_DEFAULT_CLEVEL; /* 0 == default; no negative compressionLevel yet */ + if (compressionLevel > ZSTD_MAX_CLEVEL) + compressionLevel = ZSTD_MAX_CLEVEL; + cp = ZSTD_defaultCParameters[tableID][compressionLevel]; + if (ZSTD_32bits()) { /* auto-correction, for 32-bits mode */ + if (cp.windowLog > ZSTD_WINDOWLOG_MAX) + cp.windowLog = ZSTD_WINDOWLOG_MAX; + if (cp.chainLog > ZSTD_CHAINLOG_MAX) + cp.chainLog = ZSTD_CHAINLOG_MAX; + if (cp.hashLog > ZSTD_HASHLOG_MAX) + cp.hashLog = ZSTD_HASHLOG_MAX; + } + cp = ZSTD_adjustCParams(cp, srcSize, dictSize); + return cp; +} + +/*! ZSTD_getParams() : +* same as ZSTD_getCParams(), but @return a `ZSTD_parameters` object (instead of `ZSTD_compressionParameters`). +* All fields of `ZSTD_frameParameters` are set to default (0) */ +ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSize, size_t dictSize) +{ + ZSTD_parameters params; + ZSTD_compressionParameters const cParams = ZSTD_getCParams(compressionLevel, srcSize, dictSize); + memset(¶ms, 0, sizeof(params)); + params.cParams = cParams; + return params; +} + +EXPORT_SYMBOL(ZSTD_maxCLevel); +EXPORT_SYMBOL(ZSTD_compressBound); + +EXPORT_SYMBOL(ZSTD_CCtxWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initCCtx); +EXPORT_SYMBOL(ZSTD_compressCCtx); +EXPORT_SYMBOL(ZSTD_compress_usingDict); + +EXPORT_SYMBOL(ZSTD_CDictWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initCDict); +EXPORT_SYMBOL(ZSTD_compress_usingCDict); + +EXPORT_SYMBOL(ZSTD_CStreamWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initCStream); +EXPORT_SYMBOL(ZSTD_initCStream_usingCDict); +EXPORT_SYMBOL(ZSTD_resetCStream); +EXPORT_SYMBOL(ZSTD_compressStream); +EXPORT_SYMBOL(ZSTD_flushStream); +EXPORT_SYMBOL(ZSTD_endStream); +EXPORT_SYMBOL(ZSTD_CStreamInSize); +EXPORT_SYMBOL(ZSTD_CStreamOutSize); + +EXPORT_SYMBOL(ZSTD_getCParams); +EXPORT_SYMBOL(ZSTD_getParams); +EXPORT_SYMBOL(ZSTD_checkCParams); +EXPORT_SYMBOL(ZSTD_adjustCParams); + +EXPORT_SYMBOL(ZSTD_compressBegin); +EXPORT_SYMBOL(ZSTD_compressBegin_usingDict); +EXPORT_SYMBOL(ZSTD_compressBegin_advanced); +EXPORT_SYMBOL(ZSTD_copyCCtx); +EXPORT_SYMBOL(ZSTD_compressBegin_usingCDict); +EXPORT_SYMBOL(ZSTD_compressContinue); +EXPORT_SYMBOL(ZSTD_compressEnd); + +EXPORT_SYMBOL(ZSTD_getBlockSizeMax); +EXPORT_SYMBOL(ZSTD_compressBlock); + +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("Zstd Compressor"); diff --git a/lib/zstd/decompress.c b/lib/zstd/decompress.c new file mode 100644 index 000000000000..b17846725ca0 --- /dev/null +++ b/lib/zstd/decompress.c @@ -0,0 +1,2528 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +/* *************************************************************** +* Tuning parameters +*****************************************************************/ +/*! +* MAXWINDOWSIZE_DEFAULT : +* maximum window size accepted by DStream, by default. +* Frames requiring more memory will be rejected. +*/ +#ifndef ZSTD_MAXWINDOWSIZE_DEFAULT +#define ZSTD_MAXWINDOWSIZE_DEFAULT ((1 << ZSTD_WINDOWLOG_MAX) + 1) /* defined within zstd.h */ +#endif + +/*-******************************************************* +* Dependencies +*********************************************************/ +#include "fse.h" +#include "huf.h" +#include "mem.h" /* low level memory routines */ +#include "zstd_internal.h" +#include +#include +#include /* memcpy, memmove, memset */ + +#define ZSTD_PREFETCH(ptr) __builtin_prefetch(ptr, 0, 0) + +/*-************************************* +* Macros +***************************************/ +#define ZSTD_isError ERR_isError /* for inlining */ +#define FSE_isError ERR_isError +#define HUF_isError ERR_isError + +/*_******************************************************* +* Memory operations +**********************************************************/ +static void ZSTD_copy4(void *dst, const void *src) { memcpy(dst, src, 4); } + +/*-************************************************************* +* Context management +***************************************************************/ +typedef enum { + ZSTDds_getFrameHeaderSize, + ZSTDds_decodeFrameHeader, + ZSTDds_decodeBlockHeader, + ZSTDds_decompressBlock, + ZSTDds_decompressLastBlock, + ZSTDds_checkChecksum, + ZSTDds_decodeSkippableHeader, + ZSTDds_skipFrame +} ZSTD_dStage; + +typedef struct { + FSE_DTable LLTable[FSE_DTABLE_SIZE_U32(LLFSELog)]; + FSE_DTable OFTable[FSE_DTABLE_SIZE_U32(OffFSELog)]; + FSE_DTable MLTable[FSE_DTABLE_SIZE_U32(MLFSELog)]; + HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ + U64 workspace[HUF_DECOMPRESS_WORKSPACE_SIZE_U32 / 2]; + U32 rep[ZSTD_REP_NUM]; +} ZSTD_entropyTables_t; + +struct ZSTD_DCtx_s { + const FSE_DTable *LLTptr; + const FSE_DTable *MLTptr; + const FSE_DTable *OFTptr; + const HUF_DTable *HUFptr; + ZSTD_entropyTables_t entropy; + const void *previousDstEnd; /* detect continuity */ + const void *base; /* start of curr segment */ + const void *vBase; /* virtual start of previous segment if it was just before curr one */ + const void *dictEnd; /* end of previous segment */ + size_t expected; + ZSTD_frameParams fParams; + blockType_e bType; /* used in ZSTD_decompressContinue(), to transfer blockType between header decoding and block decoding stages */ + ZSTD_dStage stage; + U32 litEntropy; + U32 fseEntropy; + struct xxh64_state xxhState; + size_t headerSize; + U32 dictID; + const BYTE *litPtr; + ZSTD_customMem customMem; + size_t litSize; + size_t rleSize; + BYTE litBuffer[ZSTD_BLOCKSIZE_ABSOLUTEMAX + WILDCOPY_OVERLENGTH]; + BYTE headerBuffer[ZSTD_FRAMEHEADERSIZE_MAX]; +}; /* typedef'd to ZSTD_DCtx within "zstd.h" */ + +size_t ZSTD_DCtxWorkspaceBound(void) { return ZSTD_ALIGN(sizeof(ZSTD_stack)) + ZSTD_ALIGN(sizeof(ZSTD_DCtx)); } + +size_t ZSTD_decompressBegin(ZSTD_DCtx *dctx) +{ + dctx->expected = ZSTD_frameHeaderSize_prefix; + dctx->stage = ZSTDds_getFrameHeaderSize; + dctx->previousDstEnd = NULL; + dctx->base = NULL; + dctx->vBase = NULL; + dctx->dictEnd = NULL; + dctx->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + dctx->litEntropy = dctx->fseEntropy = 0; + dctx->dictID = 0; + ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue)); + memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue)); /* initial repcodes */ + dctx->LLTptr = dctx->entropy.LLTable; + dctx->MLTptr = dctx->entropy.MLTable; + dctx->OFTptr = dctx->entropy.OFTable; + dctx->HUFptr = dctx->entropy.hufTable; + return 0; +} + +ZSTD_DCtx *ZSTD_createDCtx_advanced(ZSTD_customMem customMem) +{ + ZSTD_DCtx *dctx; + + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + + dctx = (ZSTD_DCtx *)ZSTD_malloc(sizeof(ZSTD_DCtx), customMem); + if (!dctx) + return NULL; + memcpy(&dctx->customMem, &customMem, sizeof(customMem)); + ZSTD_decompressBegin(dctx); + return dctx; +} + +ZSTD_DCtx *ZSTD_initDCtx(void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + return ZSTD_createDCtx_advanced(stackMem); +} + +size_t ZSTD_freeDCtx(ZSTD_DCtx *dctx) +{ + if (dctx == NULL) + return 0; /* support free on NULL */ + ZSTD_free(dctx, dctx->customMem); + return 0; /* reserved as a potential error code in the future */ +} + +void ZSTD_copyDCtx(ZSTD_DCtx *dstDCtx, const ZSTD_DCtx *srcDCtx) +{ + size_t const workSpaceSize = (ZSTD_BLOCKSIZE_ABSOLUTEMAX + WILDCOPY_OVERLENGTH) + ZSTD_frameHeaderSize_max; + memcpy(dstDCtx, srcDCtx, sizeof(ZSTD_DCtx) - workSpaceSize); /* no need to copy workspace */ +} + +static void ZSTD_refDDict(ZSTD_DCtx *dstDCtx, const ZSTD_DDict *ddict); + +/*-************************************************************* +* Decompression section +***************************************************************/ + +/*! ZSTD_isFrame() : + * Tells if the content of `buffer` starts with a valid Frame Identifier. + * Note : Frame Identifier is 4 bytes. If `size < 4`, @return will always be 0. + * Note 2 : Legacy Frame Identifiers are considered valid only if Legacy Support is enabled. + * Note 3 : Skippable Frame Identifiers are considered valid. */ +unsigned ZSTD_isFrame(const void *buffer, size_t size) +{ + if (size < 4) + return 0; + { + U32 const magic = ZSTD_readLE32(buffer); + if (magic == ZSTD_MAGICNUMBER) + return 1; + if ((magic & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) + return 1; + } + return 0; +} + +/** ZSTD_frameHeaderSize() : +* srcSize must be >= ZSTD_frameHeaderSize_prefix. +* @return : size of the Frame Header */ +static size_t ZSTD_frameHeaderSize(const void *src, size_t srcSize) +{ + if (srcSize < ZSTD_frameHeaderSize_prefix) + return ERROR(srcSize_wrong); + { + BYTE const fhd = ((const BYTE *)src)[4]; + U32 const dictID = fhd & 3; + U32 const singleSegment = (fhd >> 5) & 1; + U32 const fcsId = fhd >> 6; + return ZSTD_frameHeaderSize_prefix + !singleSegment + ZSTD_did_fieldSize[dictID] + ZSTD_fcs_fieldSize[fcsId] + (singleSegment && !fcsId); + } +} + +/** ZSTD_getFrameParams() : +* decode Frame Header, or require larger `srcSize`. +* @return : 0, `fparamsPtr` is correctly filled, +* >0, `srcSize` is too small, result is expected `srcSize`, +* or an error code, which can be tested using ZSTD_isError() */ +size_t ZSTD_getFrameParams(ZSTD_frameParams *fparamsPtr, const void *src, size_t srcSize) +{ + const BYTE *ip = (const BYTE *)src; + + if (srcSize < ZSTD_frameHeaderSize_prefix) + return ZSTD_frameHeaderSize_prefix; + if (ZSTD_readLE32(src) != ZSTD_MAGICNUMBER) { + if ((ZSTD_readLE32(src) & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { + if (srcSize < ZSTD_skippableHeaderSize) + return ZSTD_skippableHeaderSize; /* magic number + skippable frame length */ + memset(fparamsPtr, 0, sizeof(*fparamsPtr)); + fparamsPtr->frameContentSize = ZSTD_readLE32((const char *)src + 4); + fparamsPtr->windowSize = 0; /* windowSize==0 means a frame is skippable */ + return 0; + } + return ERROR(prefix_unknown); + } + + /* ensure there is enough `srcSize` to fully read/decode frame header */ + { + size_t const fhsize = ZSTD_frameHeaderSize(src, srcSize); + if (srcSize < fhsize) + return fhsize; + } + + { + BYTE const fhdByte = ip[4]; + size_t pos = 5; + U32 const dictIDSizeCode = fhdByte & 3; + U32 const checksumFlag = (fhdByte >> 2) & 1; + U32 const singleSegment = (fhdByte >> 5) & 1; + U32 const fcsID = fhdByte >> 6; + U32 const windowSizeMax = 1U << ZSTD_WINDOWLOG_MAX; + U32 windowSize = 0; + U32 dictID = 0; + U64 frameContentSize = 0; + if ((fhdByte & 0x08) != 0) + return ERROR(frameParameter_unsupported); /* reserved bits, which must be zero */ + if (!singleSegment) { + BYTE const wlByte = ip[pos++]; + U32 const windowLog = (wlByte >> 3) + ZSTD_WINDOWLOG_ABSOLUTEMIN; + if (windowLog > ZSTD_WINDOWLOG_MAX) + return ERROR(frameParameter_windowTooLarge); /* avoids issue with 1 << windowLog */ + windowSize = (1U << windowLog); + windowSize += (windowSize >> 3) * (wlByte & 7); + } + + switch (dictIDSizeCode) { + default: /* impossible */ + case 0: break; + case 1: + dictID = ip[pos]; + pos++; + break; + case 2: + dictID = ZSTD_readLE16(ip + pos); + pos += 2; + break; + case 3: + dictID = ZSTD_readLE32(ip + pos); + pos += 4; + break; + } + switch (fcsID) { + default: /* impossible */ + case 0: + if (singleSegment) + frameContentSize = ip[pos]; + break; + case 1: frameContentSize = ZSTD_readLE16(ip + pos) + 256; break; + case 2: frameContentSize = ZSTD_readLE32(ip + pos); break; + case 3: frameContentSize = ZSTD_readLE64(ip + pos); break; + } + if (!windowSize) + windowSize = (U32)frameContentSize; + if (windowSize > windowSizeMax) + return ERROR(frameParameter_windowTooLarge); + fparamsPtr->frameContentSize = frameContentSize; + fparamsPtr->windowSize = windowSize; + fparamsPtr->dictID = dictID; + fparamsPtr->checksumFlag = checksumFlag; + } + return 0; +} + +/** ZSTD_getFrameContentSize() : +* compatible with legacy mode +* @return : decompressed size of the single frame pointed to be `src` if known, otherwise +* - ZSTD_CONTENTSIZE_UNKNOWN if the size cannot be determined +* - ZSTD_CONTENTSIZE_ERROR if an error occurred (e.g. invalid magic number, srcSize too small) */ +unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize) +{ + { + ZSTD_frameParams fParams; + if (ZSTD_getFrameParams(&fParams, src, srcSize) != 0) + return ZSTD_CONTENTSIZE_ERROR; + if (fParams.windowSize == 0) { + /* Either skippable or empty frame, size == 0 either way */ + return 0; + } else if (fParams.frameContentSize != 0) { + return fParams.frameContentSize; + } else { + return ZSTD_CONTENTSIZE_UNKNOWN; + } + } +} + +/** ZSTD_findDecompressedSize() : + * compatible with legacy mode + * `srcSize` must be the exact length of some number of ZSTD compressed and/or + * skippable frames + * @return : decompressed size of the frames contained */ +unsigned long long ZSTD_findDecompressedSize(const void *src, size_t srcSize) +{ + { + unsigned long long totalDstSize = 0; + while (srcSize >= ZSTD_frameHeaderSize_prefix) { + const U32 magicNumber = ZSTD_readLE32(src); + + if ((magicNumber & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { + size_t skippableSize; + if (srcSize < ZSTD_skippableHeaderSize) + return ERROR(srcSize_wrong); + skippableSize = ZSTD_readLE32((const BYTE *)src + 4) + ZSTD_skippableHeaderSize; + if (srcSize < skippableSize) { + return ZSTD_CONTENTSIZE_ERROR; + } + + src = (const BYTE *)src + skippableSize; + srcSize -= skippableSize; + continue; + } + + { + unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize); + if (ret >= ZSTD_CONTENTSIZE_ERROR) + return ret; + + /* check for overflow */ + if (totalDstSize + ret < totalDstSize) + return ZSTD_CONTENTSIZE_ERROR; + totalDstSize += ret; + } + { + size_t const frameSrcSize = ZSTD_findFrameCompressedSize(src, srcSize); + if (ZSTD_isError(frameSrcSize)) { + return ZSTD_CONTENTSIZE_ERROR; + } + + src = (const BYTE *)src + frameSrcSize; + srcSize -= frameSrcSize; + } + } + + if (srcSize) { + return ZSTD_CONTENTSIZE_ERROR; + } + + return totalDstSize; + } +} + +/** ZSTD_decodeFrameHeader() : +* `headerSize` must be the size provided by ZSTD_frameHeaderSize(). +* @return : 0 if success, or an error code, which can be tested using ZSTD_isError() */ +static size_t ZSTD_decodeFrameHeader(ZSTD_DCtx *dctx, const void *src, size_t headerSize) +{ + size_t const result = ZSTD_getFrameParams(&(dctx->fParams), src, headerSize); + if (ZSTD_isError(result)) + return result; /* invalid header */ + if (result > 0) + return ERROR(srcSize_wrong); /* headerSize too small */ + if (dctx->fParams.dictID && (dctx->dictID != dctx->fParams.dictID)) + return ERROR(dictionary_wrong); + if (dctx->fParams.checksumFlag) + xxh64_reset(&dctx->xxhState, 0); + return 0; +} + +typedef struct { + blockType_e blockType; + U32 lastBlock; + U32 origSize; +} blockProperties_t; + +/*! ZSTD_getcBlockSize() : +* Provides the size of compressed block from block header `src` */ +size_t ZSTD_getcBlockSize(const void *src, size_t srcSize, blockProperties_t *bpPtr) +{ + if (srcSize < ZSTD_blockHeaderSize) + return ERROR(srcSize_wrong); + { + U32 const cBlockHeader = ZSTD_readLE24(src); + U32 const cSize = cBlockHeader >> 3; + bpPtr->lastBlock = cBlockHeader & 1; + bpPtr->blockType = (blockType_e)((cBlockHeader >> 1) & 3); + bpPtr->origSize = cSize; /* only useful for RLE */ + if (bpPtr->blockType == bt_rle) + return 1; + if (bpPtr->blockType == bt_reserved) + return ERROR(corruption_detected); + return cSize; + } +} + +static size_t ZSTD_copyRawBlock(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + if (srcSize > dstCapacity) + return ERROR(dstSize_tooSmall); + memcpy(dst, src, srcSize); + return srcSize; +} + +static size_t ZSTD_setRleBlock(void *dst, size_t dstCapacity, const void *src, size_t srcSize, size_t regenSize) +{ + if (srcSize != 1) + return ERROR(srcSize_wrong); + if (regenSize > dstCapacity) + return ERROR(dstSize_tooSmall); + memset(dst, *(const BYTE *)src, regenSize); + return regenSize; +} + +/*! ZSTD_decodeLiteralsBlock() : + @return : nb of bytes read from src (< srcSize ) */ +size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx *dctx, const void *src, size_t srcSize) /* note : srcSize < BLOCKSIZE */ +{ + if (srcSize < MIN_CBLOCK_SIZE) + return ERROR(corruption_detected); + + { + const BYTE *const istart = (const BYTE *)src; + symbolEncodingType_e const litEncType = (symbolEncodingType_e)(istart[0] & 3); + + switch (litEncType) { + case set_repeat: + if (dctx->litEntropy == 0) + return ERROR(dictionary_corrupted); + /* fall-through */ + case set_compressed: + if (srcSize < 5) + return ERROR(corruption_detected); /* srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3 */ + { + size_t lhSize, litSize, litCSize; + U32 singleStream = 0; + U32 const lhlCode = (istart[0] >> 2) & 3; + U32 const lhc = ZSTD_readLE32(istart); + switch (lhlCode) { + case 0: + case 1: + default: /* note : default is impossible, since lhlCode into [0..3] */ + /* 2 - 2 - 10 - 10 */ + singleStream = !lhlCode; + lhSize = 3; + litSize = (lhc >> 4) & 0x3FF; + litCSize = (lhc >> 14) & 0x3FF; + break; + case 2: + /* 2 - 2 - 14 - 14 */ + lhSize = 4; + litSize = (lhc >> 4) & 0x3FFF; + litCSize = lhc >> 18; + break; + case 3: + /* 2 - 2 - 18 - 18 */ + lhSize = 5; + litSize = (lhc >> 4) & 0x3FFFF; + litCSize = (lhc >> 22) + (istart[4] << 10); + break; + } + if (litSize > ZSTD_BLOCKSIZE_ABSOLUTEMAX) + return ERROR(corruption_detected); + if (litCSize + lhSize > srcSize) + return ERROR(corruption_detected); + + if (HUF_isError( + (litEncType == set_repeat) + ? (singleStream ? HUF_decompress1X_usingDTable(dctx->litBuffer, litSize, istart + lhSize, litCSize, dctx->HUFptr) + : HUF_decompress4X_usingDTable(dctx->litBuffer, litSize, istart + lhSize, litCSize, dctx->HUFptr)) + : (singleStream + ? HUF_decompress1X2_DCtx_wksp(dctx->entropy.hufTable, dctx->litBuffer, litSize, istart + lhSize, litCSize, + dctx->entropy.workspace, sizeof(dctx->entropy.workspace)) + : HUF_decompress4X_hufOnly_wksp(dctx->entropy.hufTable, dctx->litBuffer, litSize, istart + lhSize, litCSize, + dctx->entropy.workspace, sizeof(dctx->entropy.workspace))))) + return ERROR(corruption_detected); + + dctx->litPtr = dctx->litBuffer; + dctx->litSize = litSize; + dctx->litEntropy = 1; + if (litEncType == set_compressed) + dctx->HUFptr = dctx->entropy.hufTable; + memset(dctx->litBuffer + dctx->litSize, 0, WILDCOPY_OVERLENGTH); + return litCSize + lhSize; + } + + case set_basic: { + size_t litSize, lhSize; + U32 const lhlCode = ((istart[0]) >> 2) & 3; + switch (lhlCode) { + case 0: + case 2: + default: /* note : default is impossible, since lhlCode into [0..3] */ + lhSize = 1; + litSize = istart[0] >> 3; + break; + case 1: + lhSize = 2; + litSize = ZSTD_readLE16(istart) >> 4; + break; + case 3: + lhSize = 3; + litSize = ZSTD_readLE24(istart) >> 4; + break; + } + + if (lhSize + litSize + WILDCOPY_OVERLENGTH > srcSize) { /* risk reading beyond src buffer with wildcopy */ + if (litSize + lhSize > srcSize) + return ERROR(corruption_detected); + memcpy(dctx->litBuffer, istart + lhSize, litSize); + dctx->litPtr = dctx->litBuffer; + dctx->litSize = litSize; + memset(dctx->litBuffer + dctx->litSize, 0, WILDCOPY_OVERLENGTH); + return lhSize + litSize; + } + /* direct reference into compressed stream */ + dctx->litPtr = istart + lhSize; + dctx->litSize = litSize; + return lhSize + litSize; + } + + case set_rle: { + U32 const lhlCode = ((istart[0]) >> 2) & 3; + size_t litSize, lhSize; + switch (lhlCode) { + case 0: + case 2: + default: /* note : default is impossible, since lhlCode into [0..3] */ + lhSize = 1; + litSize = istart[0] >> 3; + break; + case 1: + lhSize = 2; + litSize = ZSTD_readLE16(istart) >> 4; + break; + case 3: + lhSize = 3; + litSize = ZSTD_readLE24(istart) >> 4; + if (srcSize < 4) + return ERROR(corruption_detected); /* srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4 */ + break; + } + if (litSize > ZSTD_BLOCKSIZE_ABSOLUTEMAX) + return ERROR(corruption_detected); + memset(dctx->litBuffer, istart[lhSize], litSize + WILDCOPY_OVERLENGTH); + dctx->litPtr = dctx->litBuffer; + dctx->litSize = litSize; + return lhSize + 1; + } + default: + return ERROR(corruption_detected); /* impossible */ + } + } +} + +typedef union { + FSE_decode_t realData; + U32 alignedBy4; +} FSE_decode_t4; + +static const FSE_decode_t4 LL_defaultDTable[(1 << LL_DEFAULTNORMLOG) + 1] = { + {{LL_DEFAULTNORMLOG, 1, 1}}, /* header : tableLog, fastMode, fastMode */ + {{0, 0, 4}}, /* 0 : base, symbol, bits */ + {{16, 0, 4}}, + {{32, 1, 5}}, + {{0, 3, 5}}, + {{0, 4, 5}}, + {{0, 6, 5}}, + {{0, 7, 5}}, + {{0, 9, 5}}, + {{0, 10, 5}}, + {{0, 12, 5}}, + {{0, 14, 6}}, + {{0, 16, 5}}, + {{0, 18, 5}}, + {{0, 19, 5}}, + {{0, 21, 5}}, + {{0, 22, 5}}, + {{0, 24, 5}}, + {{32, 25, 5}}, + {{0, 26, 5}}, + {{0, 27, 6}}, + {{0, 29, 6}}, + {{0, 31, 6}}, + {{32, 0, 4}}, + {{0, 1, 4}}, + {{0, 2, 5}}, + {{32, 4, 5}}, + {{0, 5, 5}}, + {{32, 7, 5}}, + {{0, 8, 5}}, + {{32, 10, 5}}, + {{0, 11, 5}}, + {{0, 13, 6}}, + {{32, 16, 5}}, + {{0, 17, 5}}, + {{32, 19, 5}}, + {{0, 20, 5}}, + {{32, 22, 5}}, + {{0, 23, 5}}, + {{0, 25, 4}}, + {{16, 25, 4}}, + {{32, 26, 5}}, + {{0, 28, 6}}, + {{0, 30, 6}}, + {{48, 0, 4}}, + {{16, 1, 4}}, + {{32, 2, 5}}, + {{32, 3, 5}}, + {{32, 5, 5}}, + {{32, 6, 5}}, + {{32, 8, 5}}, + {{32, 9, 5}}, + {{32, 11, 5}}, + {{32, 12, 5}}, + {{0, 15, 6}}, + {{32, 17, 5}}, + {{32, 18, 5}}, + {{32, 20, 5}}, + {{32, 21, 5}}, + {{32, 23, 5}}, + {{32, 24, 5}}, + {{0, 35, 6}}, + {{0, 34, 6}}, + {{0, 33, 6}}, + {{0, 32, 6}}, +}; /* LL_defaultDTable */ + +static const FSE_decode_t4 ML_defaultDTable[(1 << ML_DEFAULTNORMLOG) + 1] = { + {{ML_DEFAULTNORMLOG, 1, 1}}, /* header : tableLog, fastMode, fastMode */ + {{0, 0, 6}}, /* 0 : base, symbol, bits */ + {{0, 1, 4}}, + {{32, 2, 5}}, + {{0, 3, 5}}, + {{0, 5, 5}}, + {{0, 6, 5}}, + {{0, 8, 5}}, + {{0, 10, 6}}, + {{0, 13, 6}}, + {{0, 16, 6}}, + {{0, 19, 6}}, + {{0, 22, 6}}, + {{0, 25, 6}}, + {{0, 28, 6}}, + {{0, 31, 6}}, + {{0, 33, 6}}, + {{0, 35, 6}}, + {{0, 37, 6}}, + {{0, 39, 6}}, + {{0, 41, 6}}, + {{0, 43, 6}}, + {{0, 45, 6}}, + {{16, 1, 4}}, + {{0, 2, 4}}, + {{32, 3, 5}}, + {{0, 4, 5}}, + {{32, 6, 5}}, + {{0, 7, 5}}, + {{0, 9, 6}}, + {{0, 12, 6}}, + {{0, 15, 6}}, + {{0, 18, 6}}, + {{0, 21, 6}}, + {{0, 24, 6}}, + {{0, 27, 6}}, + {{0, 30, 6}}, + {{0, 32, 6}}, + {{0, 34, 6}}, + {{0, 36, 6}}, + {{0, 38, 6}}, + {{0, 40, 6}}, + {{0, 42, 6}}, + {{0, 44, 6}}, + {{32, 1, 4}}, + {{48, 1, 4}}, + {{16, 2, 4}}, + {{32, 4, 5}}, + {{32, 5, 5}}, + {{32, 7, 5}}, + {{32, 8, 5}}, + {{0, 11, 6}}, + {{0, 14, 6}}, + {{0, 17, 6}}, + {{0, 20, 6}}, + {{0, 23, 6}}, + {{0, 26, 6}}, + {{0, 29, 6}}, + {{0, 52, 6}}, + {{0, 51, 6}}, + {{0, 50, 6}}, + {{0, 49, 6}}, + {{0, 48, 6}}, + {{0, 47, 6}}, + {{0, 46, 6}}, +}; /* ML_defaultDTable */ + +static const FSE_decode_t4 OF_defaultDTable[(1 << OF_DEFAULTNORMLOG) + 1] = { + {{OF_DEFAULTNORMLOG, 1, 1}}, /* header : tableLog, fastMode, fastMode */ + {{0, 0, 5}}, /* 0 : base, symbol, bits */ + {{0, 6, 4}}, + {{0, 9, 5}}, + {{0, 15, 5}}, + {{0, 21, 5}}, + {{0, 3, 5}}, + {{0, 7, 4}}, + {{0, 12, 5}}, + {{0, 18, 5}}, + {{0, 23, 5}}, + {{0, 5, 5}}, + {{0, 8, 4}}, + {{0, 14, 5}}, + {{0, 20, 5}}, + {{0, 2, 5}}, + {{16, 7, 4}}, + {{0, 11, 5}}, + {{0, 17, 5}}, + {{0, 22, 5}}, + {{0, 4, 5}}, + {{16, 8, 4}}, + {{0, 13, 5}}, + {{0, 19, 5}}, + {{0, 1, 5}}, + {{16, 6, 4}}, + {{0, 10, 5}}, + {{0, 16, 5}}, + {{0, 28, 5}}, + {{0, 27, 5}}, + {{0, 26, 5}}, + {{0, 25, 5}}, + {{0, 24, 5}}, +}; /* OF_defaultDTable */ + +/*! ZSTD_buildSeqTable() : + @return : nb bytes read from src, + or an error code if it fails, testable with ZSTD_isError() +*/ +static size_t ZSTD_buildSeqTable(FSE_DTable *DTableSpace, const FSE_DTable **DTablePtr, symbolEncodingType_e type, U32 max, U32 maxLog, const void *src, + size_t srcSize, const FSE_decode_t4 *defaultTable, U32 flagRepeatTable, void *workspace, size_t workspaceSize) +{ + const void *const tmpPtr = defaultTable; /* bypass strict aliasing */ + switch (type) { + case set_rle: + if (!srcSize) + return ERROR(srcSize_wrong); + if ((*(const BYTE *)src) > max) + return ERROR(corruption_detected); + FSE_buildDTable_rle(DTableSpace, *(const BYTE *)src); + *DTablePtr = DTableSpace; + return 1; + case set_basic: *DTablePtr = (const FSE_DTable *)tmpPtr; return 0; + case set_repeat: + if (!flagRepeatTable) + return ERROR(corruption_detected); + return 0; + default: /* impossible */ + case set_compressed: { + U32 tableLog; + S16 *norm = (S16 *)workspace; + size_t const spaceUsed32 = ALIGN(sizeof(S16) * (MaxSeq + 1), sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(GENERIC); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + { + size_t const headerSize = FSE_readNCount(norm, &max, &tableLog, src, srcSize); + if (FSE_isError(headerSize)) + return ERROR(corruption_detected); + if (tableLog > maxLog) + return ERROR(corruption_detected); + FSE_buildDTable_wksp(DTableSpace, norm, max, tableLog, workspace, workspaceSize); + *DTablePtr = DTableSpace; + return headerSize; + } + } + } +} + +size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx *dctx, int *nbSeqPtr, const void *src, size_t srcSize) +{ + const BYTE *const istart = (const BYTE *const)src; + const BYTE *const iend = istart + srcSize; + const BYTE *ip = istart; + + /* check */ + if (srcSize < MIN_SEQUENCES_SIZE) + return ERROR(srcSize_wrong); + + /* SeqHead */ + { + int nbSeq = *ip++; + if (!nbSeq) { + *nbSeqPtr = 0; + return 1; + } + if (nbSeq > 0x7F) { + if (nbSeq == 0xFF) { + if (ip + 2 > iend) + return ERROR(srcSize_wrong); + nbSeq = ZSTD_readLE16(ip) + LONGNBSEQ, ip += 2; + } else { + if (ip >= iend) + return ERROR(srcSize_wrong); + nbSeq = ((nbSeq - 0x80) << 8) + *ip++; + } + } + *nbSeqPtr = nbSeq; + } + + /* FSE table descriptors */ + if (ip + 4 > iend) + return ERROR(srcSize_wrong); /* minimum possible size */ + { + symbolEncodingType_e const LLtype = (symbolEncodingType_e)(*ip >> 6); + symbolEncodingType_e const OFtype = (symbolEncodingType_e)((*ip >> 4) & 3); + symbolEncodingType_e const MLtype = (symbolEncodingType_e)((*ip >> 2) & 3); + ip++; + + /* Build DTables */ + { + size_t const llhSize = ZSTD_buildSeqTable(dctx->entropy.LLTable, &dctx->LLTptr, LLtype, MaxLL, LLFSELog, ip, iend - ip, + LL_defaultDTable, dctx->fseEntropy, dctx->entropy.workspace, sizeof(dctx->entropy.workspace)); + if (ZSTD_isError(llhSize)) + return ERROR(corruption_detected); + ip += llhSize; + } + { + size_t const ofhSize = ZSTD_buildSeqTable(dctx->entropy.OFTable, &dctx->OFTptr, OFtype, MaxOff, OffFSELog, ip, iend - ip, + OF_defaultDTable, dctx->fseEntropy, dctx->entropy.workspace, sizeof(dctx->entropy.workspace)); + if (ZSTD_isError(ofhSize)) + return ERROR(corruption_detected); + ip += ofhSize; + } + { + size_t const mlhSize = ZSTD_buildSeqTable(dctx->entropy.MLTable, &dctx->MLTptr, MLtype, MaxML, MLFSELog, ip, iend - ip, + ML_defaultDTable, dctx->fseEntropy, dctx->entropy.workspace, sizeof(dctx->entropy.workspace)); + if (ZSTD_isError(mlhSize)) + return ERROR(corruption_detected); + ip += mlhSize; + } + } + + return ip - istart; +} + +typedef struct { + size_t litLength; + size_t matchLength; + size_t offset; + const BYTE *match; +} seq_t; + +typedef struct { + BIT_DStream_t DStream; + FSE_DState_t stateLL; + FSE_DState_t stateOffb; + FSE_DState_t stateML; + size_t prevOffset[ZSTD_REP_NUM]; + const BYTE *base; + size_t pos; + uPtrDiff gotoDict; +} seqState_t; + +FORCE_NOINLINE +size_t ZSTD_execSequenceLast7(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const base, + const BYTE *const vBase, const BYTE *const dictEnd) +{ + BYTE *const oLitEnd = op + sequence.litLength; + size_t const sequenceLength = sequence.litLength + sequence.matchLength; + BYTE *const oMatchEnd = op + sequenceLength; /* risk : address space overflow (32-bits) */ + BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH; + const BYTE *const iLitEnd = *litPtr + sequence.litLength; + const BYTE *match = oLitEnd - sequence.offset; + + /* check */ + if (oMatchEnd > oend) + return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */ + if (iLitEnd > litLimit) + return ERROR(corruption_detected); /* over-read beyond lit buffer */ + if (oLitEnd <= oend_w) + return ERROR(GENERIC); /* Precondition */ + + /* copy literals */ + if (op < oend_w) { + ZSTD_wildcopy(op, *litPtr, oend_w - op); + *litPtr += oend_w - op; + op = oend_w; + } + while (op < oLitEnd) + *op++ = *(*litPtr)++; + + /* copy Match */ + if (sequence.offset > (size_t)(oLitEnd - base)) { + /* offset beyond prefix */ + if (sequence.offset > (size_t)(oLitEnd - vBase)) + return ERROR(corruption_detected); + match = dictEnd - (base - match); + if (match + sequence.matchLength <= dictEnd) { + memmove(oLitEnd, match, sequence.matchLength); + return sequenceLength; + } + /* span extDict & currPrefixSegment */ + { + size_t const length1 = dictEnd - match; + memmove(oLitEnd, match, length1); + op = oLitEnd + length1; + sequence.matchLength -= length1; + match = base; + } + } + while (op < oMatchEnd) + *op++ = *match++; + return sequenceLength; +} + +static seq_t ZSTD_decodeSequence(seqState_t *seqState) +{ + seq_t seq; + + U32 const llCode = FSE_peekSymbol(&seqState->stateLL); + U32 const mlCode = FSE_peekSymbol(&seqState->stateML); + U32 const ofCode = FSE_peekSymbol(&seqState->stateOffb); /* <= maxOff, by table construction */ + + U32 const llBits = LL_bits[llCode]; + U32 const mlBits = ML_bits[mlCode]; + U32 const ofBits = ofCode; + U32 const totalBits = llBits + mlBits + ofBits; + + static const U32 LL_base[MaxLL + 1] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, + 20, 22, 24, 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x10000}; + + static const U32 ML_base[MaxML + 1] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, + 43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003, 0x4003, 0x8003, 0x10003}; + + static const U32 OF_base[MaxOff + 1] = {0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, + 0x3FD, 0x7FD, 0xFFD, 0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, + 0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD}; + + /* sequence */ + { + size_t offset; + if (!ofCode) + offset = 0; + else { + offset = OF_base[ofCode] + BIT_readBitsFast(&seqState->DStream, ofBits); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */ + if (ZSTD_32bits()) + BIT_reloadDStream(&seqState->DStream); + } + + if (ofCode <= 1) { + offset += (llCode == 0); + if (offset) { + size_t temp = (offset == 3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset]; + temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */ + if (offset != 1) + seqState->prevOffset[2] = seqState->prevOffset[1]; + seqState->prevOffset[1] = seqState->prevOffset[0]; + seqState->prevOffset[0] = offset = temp; + } else { + offset = seqState->prevOffset[0]; + } + } else { + seqState->prevOffset[2] = seqState->prevOffset[1]; + seqState->prevOffset[1] = seqState->prevOffset[0]; + seqState->prevOffset[0] = offset; + } + seq.offset = offset; + } + + seq.matchLength = ML_base[mlCode] + ((mlCode > 31) ? BIT_readBitsFast(&seqState->DStream, mlBits) : 0); /* <= 16 bits */ + if (ZSTD_32bits() && (mlBits + llBits > 24)) + BIT_reloadDStream(&seqState->DStream); + + seq.litLength = LL_base[llCode] + ((llCode > 15) ? BIT_readBitsFast(&seqState->DStream, llBits) : 0); /* <= 16 bits */ + if (ZSTD_32bits() || (totalBits > 64 - 7 - (LLFSELog + MLFSELog + OffFSELog))) + BIT_reloadDStream(&seqState->DStream); + + /* ANS state update */ + FSE_updateState(&seqState->stateLL, &seqState->DStream); /* <= 9 bits */ + FSE_updateState(&seqState->stateML, &seqState->DStream); /* <= 9 bits */ + if (ZSTD_32bits()) + BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ + FSE_updateState(&seqState->stateOffb, &seqState->DStream); /* <= 8 bits */ + + seq.match = NULL; + + return seq; +} + +FORCE_INLINE +size_t ZSTD_execSequence(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const base, + const BYTE *const vBase, const BYTE *const dictEnd) +{ + BYTE *const oLitEnd = op + sequence.litLength; + size_t const sequenceLength = sequence.litLength + sequence.matchLength; + BYTE *const oMatchEnd = op + sequenceLength; /* risk : address space overflow (32-bits) */ + BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH; + const BYTE *const iLitEnd = *litPtr + sequence.litLength; + const BYTE *match = oLitEnd - sequence.offset; + + /* check */ + if (oMatchEnd > oend) + return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */ + if (iLitEnd > litLimit) + return ERROR(corruption_detected); /* over-read beyond lit buffer */ + if (oLitEnd > oend_w) + return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, base, vBase, dictEnd); + + /* copy Literals */ + ZSTD_copy8(op, *litPtr); + if (sequence.litLength > 8) + ZSTD_wildcopy(op + 8, (*litPtr) + 8, + sequence.litLength - 8); /* note : since oLitEnd <= oend-WILDCOPY_OVERLENGTH, no risk of overwrite beyond oend */ + op = oLitEnd; + *litPtr = iLitEnd; /* update for next sequence */ + + /* copy Match */ + if (sequence.offset > (size_t)(oLitEnd - base)) { + /* offset beyond prefix */ + if (sequence.offset > (size_t)(oLitEnd - vBase)) + return ERROR(corruption_detected); + match = dictEnd + (match - base); + if (match + sequence.matchLength <= dictEnd) { + memmove(oLitEnd, match, sequence.matchLength); + return sequenceLength; + } + /* span extDict & currPrefixSegment */ + { + size_t const length1 = dictEnd - match; + memmove(oLitEnd, match, length1); + op = oLitEnd + length1; + sequence.matchLength -= length1; + match = base; + if (op > oend_w || sequence.matchLength < MINMATCH) { + U32 i; + for (i = 0; i < sequence.matchLength; ++i) + op[i] = match[i]; + return sequenceLength; + } + } + } + /* Requirement: op <= oend_w && sequence.matchLength >= MINMATCH */ + + /* match within prefix */ + if (sequence.offset < 8) { + /* close range match, overlap */ + static const U32 dec32table[] = {0, 1, 2, 1, 4, 4, 4, 4}; /* added */ + static const int dec64table[] = {8, 8, 8, 7, 8, 9, 10, 11}; /* subtracted */ + int const sub2 = dec64table[sequence.offset]; + op[0] = match[0]; + op[1] = match[1]; + op[2] = match[2]; + op[3] = match[3]; + match += dec32table[sequence.offset]; + ZSTD_copy4(op + 4, match); + match -= sub2; + } else { + ZSTD_copy8(op, match); + } + op += 8; + match += 8; + + if (oMatchEnd > oend - (16 - MINMATCH)) { + if (op < oend_w) { + ZSTD_wildcopy(op, match, oend_w - op); + match += oend_w - op; + op = oend_w; + } + while (op < oMatchEnd) + *op++ = *match++; + } else { + ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength - 8); /* works even if matchLength < 8 */ + } + return sequenceLength; +} + +static size_t ZSTD_decompressSequences(ZSTD_DCtx *dctx, void *dst, size_t maxDstSize, const void *seqStart, size_t seqSize) +{ + const BYTE *ip = (const BYTE *)seqStart; + const BYTE *const iend = ip + seqSize; + BYTE *const ostart = (BYTE * const)dst; + BYTE *const oend = ostart + maxDstSize; + BYTE *op = ostart; + const BYTE *litPtr = dctx->litPtr; + const BYTE *const litEnd = litPtr + dctx->litSize; + const BYTE *const base = (const BYTE *)(dctx->base); + const BYTE *const vBase = (const BYTE *)(dctx->vBase); + const BYTE *const dictEnd = (const BYTE *)(dctx->dictEnd); + int nbSeq; + + /* Build Decoding Tables */ + { + size_t const seqHSize = ZSTD_decodeSeqHeaders(dctx, &nbSeq, ip, seqSize); + if (ZSTD_isError(seqHSize)) + return seqHSize; + ip += seqHSize; + } + + /* Regen sequences */ + if (nbSeq) { + seqState_t seqState; + dctx->fseEntropy = 1; + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + seqState.prevOffset[i] = dctx->entropy.rep[i]; + } + CHECK_E(BIT_initDStream(&seqState.DStream, ip, iend - ip), corruption_detected); + FSE_initDState(&seqState.stateLL, &seqState.DStream, dctx->LLTptr); + FSE_initDState(&seqState.stateOffb, &seqState.DStream, dctx->OFTptr); + FSE_initDState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); + + for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && nbSeq;) { + nbSeq--; + { + seq_t const sequence = ZSTD_decodeSequence(&seqState); + size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, base, vBase, dictEnd); + if (ZSTD_isError(oneSeqSize)) + return oneSeqSize; + op += oneSeqSize; + } + } + + /* check if reached exact end */ + if (nbSeq) + return ERROR(corruption_detected); + /* save reps for next block */ + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + dctx->entropy.rep[i] = (U32)(seqState.prevOffset[i]); + } + } + + /* last literal segment */ + { + size_t const lastLLSize = litEnd - litPtr; + if (lastLLSize > (size_t)(oend - op)) + return ERROR(dstSize_tooSmall); + memcpy(op, litPtr, lastLLSize); + op += lastLLSize; + } + + return op - ostart; +} + +FORCE_INLINE seq_t ZSTD_decodeSequenceLong_generic(seqState_t *seqState, int const longOffsets) +{ + seq_t seq; + + U32 const llCode = FSE_peekSymbol(&seqState->stateLL); + U32 const mlCode = FSE_peekSymbol(&seqState->stateML); + U32 const ofCode = FSE_peekSymbol(&seqState->stateOffb); /* <= maxOff, by table construction */ + + U32 const llBits = LL_bits[llCode]; + U32 const mlBits = ML_bits[mlCode]; + U32 const ofBits = ofCode; + U32 const totalBits = llBits + mlBits + ofBits; + + static const U32 LL_base[MaxLL + 1] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, + 20, 22, 24, 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x10000}; + + static const U32 ML_base[MaxML + 1] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, + 43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803, 0x1003, 0x2003, 0x4003, 0x8003, 0x10003}; + + static const U32 OF_base[MaxOff + 1] = {0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D, 0xFD, 0x1FD, + 0x3FD, 0x7FD, 0xFFD, 0x1FFD, 0x3FFD, 0x7FFD, 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, + 0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD, 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD}; + + /* sequence */ + { + size_t offset; + if (!ofCode) + offset = 0; + else { + if (longOffsets) { + int const extraBits = ofBits - MIN(ofBits, STREAM_ACCUMULATOR_MIN); + offset = OF_base[ofCode] + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits); + if (ZSTD_32bits() || extraBits) + BIT_reloadDStream(&seqState->DStream); + if (extraBits) + offset += BIT_readBitsFast(&seqState->DStream, extraBits); + } else { + offset = OF_base[ofCode] + BIT_readBitsFast(&seqState->DStream, ofBits); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */ + if (ZSTD_32bits()) + BIT_reloadDStream(&seqState->DStream); + } + } + + if (ofCode <= 1) { + offset += (llCode == 0); + if (offset) { + size_t temp = (offset == 3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset]; + temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */ + if (offset != 1) + seqState->prevOffset[2] = seqState->prevOffset[1]; + seqState->prevOffset[1] = seqState->prevOffset[0]; + seqState->prevOffset[0] = offset = temp; + } else { + offset = seqState->prevOffset[0]; + } + } else { + seqState->prevOffset[2] = seqState->prevOffset[1]; + seqState->prevOffset[1] = seqState->prevOffset[0]; + seqState->prevOffset[0] = offset; + } + seq.offset = offset; + } + + seq.matchLength = ML_base[mlCode] + ((mlCode > 31) ? BIT_readBitsFast(&seqState->DStream, mlBits) : 0); /* <= 16 bits */ + if (ZSTD_32bits() && (mlBits + llBits > 24)) + BIT_reloadDStream(&seqState->DStream); + + seq.litLength = LL_base[llCode] + ((llCode > 15) ? BIT_readBitsFast(&seqState->DStream, llBits) : 0); /* <= 16 bits */ + if (ZSTD_32bits() || (totalBits > 64 - 7 - (LLFSELog + MLFSELog + OffFSELog))) + BIT_reloadDStream(&seqState->DStream); + + { + size_t const pos = seqState->pos + seq.litLength; + seq.match = seqState->base + pos - seq.offset; /* single memory segment */ + if (seq.offset > pos) + seq.match += seqState->gotoDict; /* separate memory segment */ + seqState->pos = pos + seq.matchLength; + } + + /* ANS state update */ + FSE_updateState(&seqState->stateLL, &seqState->DStream); /* <= 9 bits */ + FSE_updateState(&seqState->stateML, &seqState->DStream); /* <= 9 bits */ + if (ZSTD_32bits()) + BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ + FSE_updateState(&seqState->stateOffb, &seqState->DStream); /* <= 8 bits */ + + return seq; +} + +static seq_t ZSTD_decodeSequenceLong(seqState_t *seqState, unsigned const windowSize) +{ + if (ZSTD_highbit32(windowSize) > STREAM_ACCUMULATOR_MIN) { + return ZSTD_decodeSequenceLong_generic(seqState, 1); + } else { + return ZSTD_decodeSequenceLong_generic(seqState, 0); + } +} + +FORCE_INLINE +size_t ZSTD_execSequenceLong(BYTE *op, BYTE *const oend, seq_t sequence, const BYTE **litPtr, const BYTE *const litLimit, const BYTE *const base, + const BYTE *const vBase, const BYTE *const dictEnd) +{ + BYTE *const oLitEnd = op + sequence.litLength; + size_t const sequenceLength = sequence.litLength + sequence.matchLength; + BYTE *const oMatchEnd = op + sequenceLength; /* risk : address space overflow (32-bits) */ + BYTE *const oend_w = oend - WILDCOPY_OVERLENGTH; + const BYTE *const iLitEnd = *litPtr + sequence.litLength; + const BYTE *match = sequence.match; + + /* check */ + if (oMatchEnd > oend) + return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */ + if (iLitEnd > litLimit) + return ERROR(corruption_detected); /* over-read beyond lit buffer */ + if (oLitEnd > oend_w) + return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, base, vBase, dictEnd); + + /* copy Literals */ + ZSTD_copy8(op, *litPtr); + if (sequence.litLength > 8) + ZSTD_wildcopy(op + 8, (*litPtr) + 8, + sequence.litLength - 8); /* note : since oLitEnd <= oend-WILDCOPY_OVERLENGTH, no risk of overwrite beyond oend */ + op = oLitEnd; + *litPtr = iLitEnd; /* update for next sequence */ + + /* copy Match */ + if (sequence.offset > (size_t)(oLitEnd - base)) { + /* offset beyond prefix */ + if (sequence.offset > (size_t)(oLitEnd - vBase)) + return ERROR(corruption_detected); + if (match + sequence.matchLength <= dictEnd) { + memmove(oLitEnd, match, sequence.matchLength); + return sequenceLength; + } + /* span extDict & currPrefixSegment */ + { + size_t const length1 = dictEnd - match; + memmove(oLitEnd, match, length1); + op = oLitEnd + length1; + sequence.matchLength -= length1; + match = base; + if (op > oend_w || sequence.matchLength < MINMATCH) { + U32 i; + for (i = 0; i < sequence.matchLength; ++i) + op[i] = match[i]; + return sequenceLength; + } + } + } + /* Requirement: op <= oend_w && sequence.matchLength >= MINMATCH */ + + /* match within prefix */ + if (sequence.offset < 8) { + /* close range match, overlap */ + static const U32 dec32table[] = {0, 1, 2, 1, 4, 4, 4, 4}; /* added */ + static const int dec64table[] = {8, 8, 8, 7, 8, 9, 10, 11}; /* subtracted */ + int const sub2 = dec64table[sequence.offset]; + op[0] = match[0]; + op[1] = match[1]; + op[2] = match[2]; + op[3] = match[3]; + match += dec32table[sequence.offset]; + ZSTD_copy4(op + 4, match); + match -= sub2; + } else { + ZSTD_copy8(op, match); + } + op += 8; + match += 8; + + if (oMatchEnd > oend - (16 - MINMATCH)) { + if (op < oend_w) { + ZSTD_wildcopy(op, match, oend_w - op); + match += oend_w - op; + op = oend_w; + } + while (op < oMatchEnd) + *op++ = *match++; + } else { + ZSTD_wildcopy(op, match, (ptrdiff_t)sequence.matchLength - 8); /* works even if matchLength < 8 */ + } + return sequenceLength; +} + +static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx *dctx, void *dst, size_t maxDstSize, const void *seqStart, size_t seqSize) +{ + const BYTE *ip = (const BYTE *)seqStart; + const BYTE *const iend = ip + seqSize; + BYTE *const ostart = (BYTE * const)dst; + BYTE *const oend = ostart + maxDstSize; + BYTE *op = ostart; + const BYTE *litPtr = dctx->litPtr; + const BYTE *const litEnd = litPtr + dctx->litSize; + const BYTE *const base = (const BYTE *)(dctx->base); + const BYTE *const vBase = (const BYTE *)(dctx->vBase); + const BYTE *const dictEnd = (const BYTE *)(dctx->dictEnd); + unsigned const windowSize = dctx->fParams.windowSize; + int nbSeq; + + /* Build Decoding Tables */ + { + size_t const seqHSize = ZSTD_decodeSeqHeaders(dctx, &nbSeq, ip, seqSize); + if (ZSTD_isError(seqHSize)) + return seqHSize; + ip += seqHSize; + } + + /* Regen sequences */ + if (nbSeq) { +#define STORED_SEQS 4 +#define STOSEQ_MASK (STORED_SEQS - 1) +#define ADVANCED_SEQS 4 + seq_t *sequences = (seq_t *)dctx->entropy.workspace; + int const seqAdvance = MIN(nbSeq, ADVANCED_SEQS); + seqState_t seqState; + int seqNb; + ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.workspace) >= sizeof(seq_t) * STORED_SEQS); + dctx->fseEntropy = 1; + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + seqState.prevOffset[i] = dctx->entropy.rep[i]; + } + seqState.base = base; + seqState.pos = (size_t)(op - base); + seqState.gotoDict = (uPtrDiff)dictEnd - (uPtrDiff)base; /* cast to avoid undefined behaviour */ + CHECK_E(BIT_initDStream(&seqState.DStream, ip, iend - ip), corruption_detected); + FSE_initDState(&seqState.stateLL, &seqState.DStream, dctx->LLTptr); + FSE_initDState(&seqState.stateOffb, &seqState.DStream, dctx->OFTptr); + FSE_initDState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); + + /* prepare in advance */ + for (seqNb = 0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && seqNb < seqAdvance; seqNb++) { + sequences[seqNb] = ZSTD_decodeSequenceLong(&seqState, windowSize); + } + if (seqNb < seqAdvance) + return ERROR(corruption_detected); + + /* decode and decompress */ + for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && seqNb < nbSeq; seqNb++) { + seq_t const sequence = ZSTD_decodeSequenceLong(&seqState, windowSize); + size_t const oneSeqSize = + ZSTD_execSequenceLong(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd); + if (ZSTD_isError(oneSeqSize)) + return oneSeqSize; + ZSTD_PREFETCH(sequence.match); + sequences[seqNb & STOSEQ_MASK] = sequence; + op += oneSeqSize; + } + if (seqNb < nbSeq) + return ERROR(corruption_detected); + + /* finish queue */ + seqNb -= seqAdvance; + for (; seqNb < nbSeq; seqNb++) { + size_t const oneSeqSize = ZSTD_execSequenceLong(op, oend, sequences[seqNb & STOSEQ_MASK], &litPtr, litEnd, base, vBase, dictEnd); + if (ZSTD_isError(oneSeqSize)) + return oneSeqSize; + op += oneSeqSize; + } + + /* save reps for next block */ + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + dctx->entropy.rep[i] = (U32)(seqState.prevOffset[i]); + } + } + + /* last literal segment */ + { + size_t const lastLLSize = litEnd - litPtr; + if (lastLLSize > (size_t)(oend - op)) + return ERROR(dstSize_tooSmall); + memcpy(op, litPtr, lastLLSize); + op += lastLLSize; + } + + return op - ostart; +} + +static size_t ZSTD_decompressBlock_internal(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ /* blockType == blockCompressed */ + const BYTE *ip = (const BYTE *)src; + + if (srcSize >= ZSTD_BLOCKSIZE_ABSOLUTEMAX) + return ERROR(srcSize_wrong); + + /* Decode literals section */ + { + size_t const litCSize = ZSTD_decodeLiteralsBlock(dctx, src, srcSize); + if (ZSTD_isError(litCSize)) + return litCSize; + ip += litCSize; + srcSize -= litCSize; + } + if (sizeof(size_t) > 4) /* do not enable prefetching on 32-bits x86, as it's performance detrimental */ + /* likely because of register pressure */ + /* if that's the correct cause, then 32-bits ARM should be affected differently */ + /* it would be good to test this on ARM real hardware, to see if prefetch version improves speed */ + if (dctx->fParams.windowSize > (1 << 23)) + return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize); + return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize); +} + +static void ZSTD_checkContinuity(ZSTD_DCtx *dctx, const void *dst) +{ + if (dst != dctx->previousDstEnd) { /* not contiguous */ + dctx->dictEnd = dctx->previousDstEnd; + dctx->vBase = (const char *)dst - ((const char *)(dctx->previousDstEnd) - (const char *)(dctx->base)); + dctx->base = dst; + dctx->previousDstEnd = dst; + } +} + +size_t ZSTD_decompressBlock(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t dSize; + ZSTD_checkContinuity(dctx, dst); + dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize); + dctx->previousDstEnd = (char *)dst + dSize; + return dSize; +} + +/** ZSTD_insertBlock() : + insert `src` block into `dctx` history. Useful to track uncompressed blocks. */ +size_t ZSTD_insertBlock(ZSTD_DCtx *dctx, const void *blockStart, size_t blockSize) +{ + ZSTD_checkContinuity(dctx, blockStart); + dctx->previousDstEnd = (const char *)blockStart + blockSize; + return blockSize; +} + +size_t ZSTD_generateNxBytes(void *dst, size_t dstCapacity, BYTE byte, size_t length) +{ + if (length > dstCapacity) + return ERROR(dstSize_tooSmall); + memset(dst, byte, length); + return length; +} + +/** ZSTD_findFrameCompressedSize() : + * compatible with legacy mode + * `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame + * `srcSize` must be at least as large as the frame contained + * @return : the compressed size of the frame starting at `src` */ +size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) +{ + if (srcSize >= ZSTD_skippableHeaderSize && (ZSTD_readLE32(src) & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { + return ZSTD_skippableHeaderSize + ZSTD_readLE32((const BYTE *)src + 4); + } else { + const BYTE *ip = (const BYTE *)src; + const BYTE *const ipstart = ip; + size_t remainingSize = srcSize; + ZSTD_frameParams fParams; + + size_t const headerSize = ZSTD_frameHeaderSize(ip, remainingSize); + if (ZSTD_isError(headerSize)) + return headerSize; + + /* Frame Header */ + { + size_t const ret = ZSTD_getFrameParams(&fParams, ip, remainingSize); + if (ZSTD_isError(ret)) + return ret; + if (ret > 0) + return ERROR(srcSize_wrong); + } + + ip += headerSize; + remainingSize -= headerSize; + + /* Loop on each block */ + while (1) { + blockProperties_t blockProperties; + size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSize, &blockProperties); + if (ZSTD_isError(cBlockSize)) + return cBlockSize; + + if (ZSTD_blockHeaderSize + cBlockSize > remainingSize) + return ERROR(srcSize_wrong); + + ip += ZSTD_blockHeaderSize + cBlockSize; + remainingSize -= ZSTD_blockHeaderSize + cBlockSize; + + if (blockProperties.lastBlock) + break; + } + + if (fParams.checksumFlag) { /* Frame content checksum */ + if (remainingSize < 4) + return ERROR(srcSize_wrong); + ip += 4; + remainingSize -= 4; + } + + return ip - ipstart; + } +} + +/*! ZSTD_decompressFrame() : +* @dctx must be properly initialized */ +static size_t ZSTD_decompressFrame(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void **srcPtr, size_t *srcSizePtr) +{ + const BYTE *ip = (const BYTE *)(*srcPtr); + BYTE *const ostart = (BYTE * const)dst; + BYTE *const oend = ostart + dstCapacity; + BYTE *op = ostart; + size_t remainingSize = *srcSizePtr; + + /* check */ + if (remainingSize < ZSTD_frameHeaderSize_min + ZSTD_blockHeaderSize) + return ERROR(srcSize_wrong); + + /* Frame Header */ + { + size_t const frameHeaderSize = ZSTD_frameHeaderSize(ip, ZSTD_frameHeaderSize_prefix); + if (ZSTD_isError(frameHeaderSize)) + return frameHeaderSize; + if (remainingSize < frameHeaderSize + ZSTD_blockHeaderSize) + return ERROR(srcSize_wrong); + CHECK_F(ZSTD_decodeFrameHeader(dctx, ip, frameHeaderSize)); + ip += frameHeaderSize; + remainingSize -= frameHeaderSize; + } + + /* Loop on each block */ + while (1) { + size_t decodedSize; + blockProperties_t blockProperties; + size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSize, &blockProperties); + if (ZSTD_isError(cBlockSize)) + return cBlockSize; + + ip += ZSTD_blockHeaderSize; + remainingSize -= ZSTD_blockHeaderSize; + if (cBlockSize > remainingSize) + return ERROR(srcSize_wrong); + + switch (blockProperties.blockType) { + case bt_compressed: decodedSize = ZSTD_decompressBlock_internal(dctx, op, oend - op, ip, cBlockSize); break; + case bt_raw: decodedSize = ZSTD_copyRawBlock(op, oend - op, ip, cBlockSize); break; + case bt_rle: decodedSize = ZSTD_generateNxBytes(op, oend - op, *ip, blockProperties.origSize); break; + case bt_reserved: + default: return ERROR(corruption_detected); + } + + if (ZSTD_isError(decodedSize)) + return decodedSize; + if (dctx->fParams.checksumFlag) + xxh64_update(&dctx->xxhState, op, decodedSize); + op += decodedSize; + ip += cBlockSize; + remainingSize -= cBlockSize; + if (blockProperties.lastBlock) + break; + } + + if (dctx->fParams.checksumFlag) { /* Frame content checksum verification */ + U32 const checkCalc = (U32)xxh64_digest(&dctx->xxhState); + U32 checkRead; + if (remainingSize < 4) + return ERROR(checksum_wrong); + checkRead = ZSTD_readLE32(ip); + if (checkRead != checkCalc) + return ERROR(checksum_wrong); + ip += 4; + remainingSize -= 4; + } + + /* Allow caller to get size read */ + *srcPtr = ip; + *srcSizePtr = remainingSize; + return op - ostart; +} + +static const void *ZSTD_DDictDictContent(const ZSTD_DDict *ddict); +static size_t ZSTD_DDictDictSize(const ZSTD_DDict *ddict); + +static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const void *dict, size_t dictSize, + const ZSTD_DDict *ddict) +{ + void *const dststart = dst; + + if (ddict) { + if (dict) { + /* programmer error, these two cases should be mutually exclusive */ + return ERROR(GENERIC); + } + + dict = ZSTD_DDictDictContent(ddict); + dictSize = ZSTD_DDictDictSize(ddict); + } + + while (srcSize >= ZSTD_frameHeaderSize_prefix) { + U32 magicNumber; + + magicNumber = ZSTD_readLE32(src); + if (magicNumber != ZSTD_MAGICNUMBER) { + if ((magicNumber & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { + size_t skippableSize; + if (srcSize < ZSTD_skippableHeaderSize) + return ERROR(srcSize_wrong); + skippableSize = ZSTD_readLE32((const BYTE *)src + 4) + ZSTD_skippableHeaderSize; + if (srcSize < skippableSize) { + return ERROR(srcSize_wrong); + } + + src = (const BYTE *)src + skippableSize; + srcSize -= skippableSize; + continue; + } else { + return ERROR(prefix_unknown); + } + } + + if (ddict) { + /* we were called from ZSTD_decompress_usingDDict */ + ZSTD_refDDict(dctx, ddict); + } else { + /* this will initialize correctly with no dict if dict == NULL, so + * use this in all cases but ddict */ + CHECK_F(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize)); + } + ZSTD_checkContinuity(dctx, dst); + + { + const size_t res = ZSTD_decompressFrame(dctx, dst, dstCapacity, &src, &srcSize); + if (ZSTD_isError(res)) + return res; + /* don't need to bounds check this, ZSTD_decompressFrame will have + * already */ + dst = (BYTE *)dst + res; + dstCapacity -= res; + } + } + + if (srcSize) + return ERROR(srcSize_wrong); /* input not entirely consumed */ + + return (BYTE *)dst - (BYTE *)dststart; +} + +size_t ZSTD_decompress_usingDict(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const void *dict, size_t dictSize) +{ + return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, dict, dictSize, NULL); +} + +size_t ZSTD_decompressDCtx(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + return ZSTD_decompress_usingDict(dctx, dst, dstCapacity, src, srcSize, NULL, 0); +} + +/*-************************************** +* Advanced Streaming Decompression API +* Bufferless and synchronous +****************************************/ +size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx *dctx) { return dctx->expected; } + +ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx *dctx) +{ + switch (dctx->stage) { + default: /* should not happen */ + case ZSTDds_getFrameHeaderSize: + case ZSTDds_decodeFrameHeader: return ZSTDnit_frameHeader; + case ZSTDds_decodeBlockHeader: return ZSTDnit_blockHeader; + case ZSTDds_decompressBlock: return ZSTDnit_block; + case ZSTDds_decompressLastBlock: return ZSTDnit_lastBlock; + case ZSTDds_checkChecksum: return ZSTDnit_checksum; + case ZSTDds_decodeSkippableHeader: + case ZSTDds_skipFrame: return ZSTDnit_skippableFrame; + } +} + +int ZSTD_isSkipFrame(ZSTD_DCtx *dctx) { return dctx->stage == ZSTDds_skipFrame; } /* for zbuff */ + +/** ZSTD_decompressContinue() : +* @return : nb of bytes generated into `dst` (necessarily <= `dstCapacity) +* or an error code, which can be tested using ZSTD_isError() */ +size_t ZSTD_decompressContinue(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + /* Sanity check */ + if (srcSize != dctx->expected) + return ERROR(srcSize_wrong); + if (dstCapacity) + ZSTD_checkContinuity(dctx, dst); + + switch (dctx->stage) { + case ZSTDds_getFrameHeaderSize: + if (srcSize != ZSTD_frameHeaderSize_prefix) + return ERROR(srcSize_wrong); /* impossible */ + if ((ZSTD_readLE32(src) & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ + memcpy(dctx->headerBuffer, src, ZSTD_frameHeaderSize_prefix); + dctx->expected = ZSTD_skippableHeaderSize - ZSTD_frameHeaderSize_prefix; /* magic number + skippable frame length */ + dctx->stage = ZSTDds_decodeSkippableHeader; + return 0; + } + dctx->headerSize = ZSTD_frameHeaderSize(src, ZSTD_frameHeaderSize_prefix); + if (ZSTD_isError(dctx->headerSize)) + return dctx->headerSize; + memcpy(dctx->headerBuffer, src, ZSTD_frameHeaderSize_prefix); + if (dctx->headerSize > ZSTD_frameHeaderSize_prefix) { + dctx->expected = dctx->headerSize - ZSTD_frameHeaderSize_prefix; + dctx->stage = ZSTDds_decodeFrameHeader; + return 0; + } + dctx->expected = 0; /* not necessary to copy more */ + + case ZSTDds_decodeFrameHeader: + memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected); + CHECK_F(ZSTD_decodeFrameHeader(dctx, dctx->headerBuffer, dctx->headerSize)); + dctx->expected = ZSTD_blockHeaderSize; + dctx->stage = ZSTDds_decodeBlockHeader; + return 0; + + case ZSTDds_decodeBlockHeader: { + blockProperties_t bp; + size_t const cBlockSize = ZSTD_getcBlockSize(src, ZSTD_blockHeaderSize, &bp); + if (ZSTD_isError(cBlockSize)) + return cBlockSize; + dctx->expected = cBlockSize; + dctx->bType = bp.blockType; + dctx->rleSize = bp.origSize; + if (cBlockSize) { + dctx->stage = bp.lastBlock ? ZSTDds_decompressLastBlock : ZSTDds_decompressBlock; + return 0; + } + /* empty block */ + if (bp.lastBlock) { + if (dctx->fParams.checksumFlag) { + dctx->expected = 4; + dctx->stage = ZSTDds_checkChecksum; + } else { + dctx->expected = 0; /* end of frame */ + dctx->stage = ZSTDds_getFrameHeaderSize; + } + } else { + dctx->expected = 3; /* go directly to next header */ + dctx->stage = ZSTDds_decodeBlockHeader; + } + return 0; + } + case ZSTDds_decompressLastBlock: + case ZSTDds_decompressBlock: { + size_t rSize; + switch (dctx->bType) { + case bt_compressed: rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize); break; + case bt_raw: rSize = ZSTD_copyRawBlock(dst, dstCapacity, src, srcSize); break; + case bt_rle: rSize = ZSTD_setRleBlock(dst, dstCapacity, src, srcSize, dctx->rleSize); break; + case bt_reserved: /* should never happen */ + default: return ERROR(corruption_detected); + } + if (ZSTD_isError(rSize)) + return rSize; + if (dctx->fParams.checksumFlag) + xxh64_update(&dctx->xxhState, dst, rSize); + + if (dctx->stage == ZSTDds_decompressLastBlock) { /* end of frame */ + if (dctx->fParams.checksumFlag) { /* another round for frame checksum */ + dctx->expected = 4; + dctx->stage = ZSTDds_checkChecksum; + } else { + dctx->expected = 0; /* ends here */ + dctx->stage = ZSTDds_getFrameHeaderSize; + } + } else { + dctx->stage = ZSTDds_decodeBlockHeader; + dctx->expected = ZSTD_blockHeaderSize; + dctx->previousDstEnd = (char *)dst + rSize; + } + return rSize; + } + case ZSTDds_checkChecksum: { + U32 const h32 = (U32)xxh64_digest(&dctx->xxhState); + U32 const check32 = ZSTD_readLE32(src); /* srcSize == 4, guaranteed by dctx->expected */ + if (check32 != h32) + return ERROR(checksum_wrong); + dctx->expected = 0; + dctx->stage = ZSTDds_getFrameHeaderSize; + return 0; + } + case ZSTDds_decodeSkippableHeader: { + memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected); + dctx->expected = ZSTD_readLE32(dctx->headerBuffer + 4); + dctx->stage = ZSTDds_skipFrame; + return 0; + } + case ZSTDds_skipFrame: { + dctx->expected = 0; + dctx->stage = ZSTDds_getFrameHeaderSize; + return 0; + } + default: + return ERROR(GENERIC); /* impossible */ + } +} + +static size_t ZSTD_refDictContent(ZSTD_DCtx *dctx, const void *dict, size_t dictSize) +{ + dctx->dictEnd = dctx->previousDstEnd; + dctx->vBase = (const char *)dict - ((const char *)(dctx->previousDstEnd) - (const char *)(dctx->base)); + dctx->base = dict; + dctx->previousDstEnd = (const char *)dict + dictSize; + return 0; +} + +/* ZSTD_loadEntropy() : + * dict : must point at beginning of a valid zstd dictionary + * @return : size of entropy tables read */ +static size_t ZSTD_loadEntropy(ZSTD_entropyTables_t *entropy, const void *const dict, size_t const dictSize) +{ + const BYTE *dictPtr = (const BYTE *)dict; + const BYTE *const dictEnd = dictPtr + dictSize; + + if (dictSize <= 8) + return ERROR(dictionary_corrupted); + dictPtr += 8; /* skip header = magic + dictID */ + + { + size_t const hSize = HUF_readDTableX4_wksp(entropy->hufTable, dictPtr, dictEnd - dictPtr, entropy->workspace, sizeof(entropy->workspace)); + if (HUF_isError(hSize)) + return ERROR(dictionary_corrupted); + dictPtr += hSize; + } + + { + short offcodeNCount[MaxOff + 1]; + U32 offcodeMaxValue = MaxOff, offcodeLog; + size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(offcodeHeaderSize)) + return ERROR(dictionary_corrupted); + if (offcodeLog > OffFSELog) + return ERROR(dictionary_corrupted); + CHECK_E(FSE_buildDTable_wksp(entropy->OFTable, offcodeNCount, offcodeMaxValue, offcodeLog, entropy->workspace, sizeof(entropy->workspace)), dictionary_corrupted); + dictPtr += offcodeHeaderSize; + } + + { + short matchlengthNCount[MaxML + 1]; + unsigned matchlengthMaxValue = MaxML, matchlengthLog; + size_t const matchlengthHeaderSize = FSE_readNCount(matchlengthNCount, &matchlengthMaxValue, &matchlengthLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(matchlengthHeaderSize)) + return ERROR(dictionary_corrupted); + if (matchlengthLog > MLFSELog) + return ERROR(dictionary_corrupted); + CHECK_E(FSE_buildDTable_wksp(entropy->MLTable, matchlengthNCount, matchlengthMaxValue, matchlengthLog, entropy->workspace, sizeof(entropy->workspace)), dictionary_corrupted); + dictPtr += matchlengthHeaderSize; + } + + { + short litlengthNCount[MaxLL + 1]; + unsigned litlengthMaxValue = MaxLL, litlengthLog; + size_t const litlengthHeaderSize = FSE_readNCount(litlengthNCount, &litlengthMaxValue, &litlengthLog, dictPtr, dictEnd - dictPtr); + if (FSE_isError(litlengthHeaderSize)) + return ERROR(dictionary_corrupted); + if (litlengthLog > LLFSELog) + return ERROR(dictionary_corrupted); + CHECK_E(FSE_buildDTable_wksp(entropy->LLTable, litlengthNCount, litlengthMaxValue, litlengthLog, entropy->workspace, sizeof(entropy->workspace)), dictionary_corrupted); + dictPtr += litlengthHeaderSize; + } + + if (dictPtr + 12 > dictEnd) + return ERROR(dictionary_corrupted); + { + int i; + size_t const dictContentSize = (size_t)(dictEnd - (dictPtr + 12)); + for (i = 0; i < 3; i++) { + U32 const rep = ZSTD_readLE32(dictPtr); + dictPtr += 4; + if (rep == 0 || rep >= dictContentSize) + return ERROR(dictionary_corrupted); + entropy->rep[i] = rep; + } + } + + return dictPtr - (const BYTE *)dict; +} + +static size_t ZSTD_decompress_insertDictionary(ZSTD_DCtx *dctx, const void *dict, size_t dictSize) +{ + if (dictSize < 8) + return ZSTD_refDictContent(dctx, dict, dictSize); + { + U32 const magic = ZSTD_readLE32(dict); + if (magic != ZSTD_DICT_MAGIC) { + return ZSTD_refDictContent(dctx, dict, dictSize); /* pure content mode */ + } + } + dctx->dictID = ZSTD_readLE32((const char *)dict + 4); + + /* load entropy tables */ + { + size_t const eSize = ZSTD_loadEntropy(&dctx->entropy, dict, dictSize); + if (ZSTD_isError(eSize)) + return ERROR(dictionary_corrupted); + dict = (const char *)dict + eSize; + dictSize -= eSize; + } + dctx->litEntropy = dctx->fseEntropy = 1; + + /* reference dictionary content */ + return ZSTD_refDictContent(dctx, dict, dictSize); +} + +size_t ZSTD_decompressBegin_usingDict(ZSTD_DCtx *dctx, const void *dict, size_t dictSize) +{ + CHECK_F(ZSTD_decompressBegin(dctx)); + if (dict && dictSize) + CHECK_E(ZSTD_decompress_insertDictionary(dctx, dict, dictSize), dictionary_corrupted); + return 0; +} + +/* ====== ZSTD_DDict ====== */ + +struct ZSTD_DDict_s { + void *dictBuffer; + const void *dictContent; + size_t dictSize; + ZSTD_entropyTables_t entropy; + U32 dictID; + U32 entropyPresent; + ZSTD_customMem cMem; +}; /* typedef'd to ZSTD_DDict within "zstd.h" */ + +size_t ZSTD_DDictWorkspaceBound(void) { return ZSTD_ALIGN(sizeof(ZSTD_stack)) + ZSTD_ALIGN(sizeof(ZSTD_DDict)); } + +static const void *ZSTD_DDictDictContent(const ZSTD_DDict *ddict) { return ddict->dictContent; } + +static size_t ZSTD_DDictDictSize(const ZSTD_DDict *ddict) { return ddict->dictSize; } + +static void ZSTD_refDDict(ZSTD_DCtx *dstDCtx, const ZSTD_DDict *ddict) +{ + ZSTD_decompressBegin(dstDCtx); /* init */ + if (ddict) { /* support refDDict on NULL */ + dstDCtx->dictID = ddict->dictID; + dstDCtx->base = ddict->dictContent; + dstDCtx->vBase = ddict->dictContent; + dstDCtx->dictEnd = (const BYTE *)ddict->dictContent + ddict->dictSize; + dstDCtx->previousDstEnd = dstDCtx->dictEnd; + if (ddict->entropyPresent) { + dstDCtx->litEntropy = 1; + dstDCtx->fseEntropy = 1; + dstDCtx->LLTptr = ddict->entropy.LLTable; + dstDCtx->MLTptr = ddict->entropy.MLTable; + dstDCtx->OFTptr = ddict->entropy.OFTable; + dstDCtx->HUFptr = ddict->entropy.hufTable; + dstDCtx->entropy.rep[0] = ddict->entropy.rep[0]; + dstDCtx->entropy.rep[1] = ddict->entropy.rep[1]; + dstDCtx->entropy.rep[2] = ddict->entropy.rep[2]; + } else { + dstDCtx->litEntropy = 0; + dstDCtx->fseEntropy = 0; + } + } +} + +static size_t ZSTD_loadEntropy_inDDict(ZSTD_DDict *ddict) +{ + ddict->dictID = 0; + ddict->entropyPresent = 0; + if (ddict->dictSize < 8) + return 0; + { + U32 const magic = ZSTD_readLE32(ddict->dictContent); + if (magic != ZSTD_DICT_MAGIC) + return 0; /* pure content mode */ + } + ddict->dictID = ZSTD_readLE32((const char *)ddict->dictContent + 4); + + /* load entropy tables */ + CHECK_E(ZSTD_loadEntropy(&ddict->entropy, ddict->dictContent, ddict->dictSize), dictionary_corrupted); + ddict->entropyPresent = 1; + return 0; +} + +static ZSTD_DDict *ZSTD_createDDict_advanced(const void *dict, size_t dictSize, unsigned byReference, ZSTD_customMem customMem) +{ + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + + { + ZSTD_DDict *const ddict = (ZSTD_DDict *)ZSTD_malloc(sizeof(ZSTD_DDict), customMem); + if (!ddict) + return NULL; + ddict->cMem = customMem; + + if ((byReference) || (!dict) || (!dictSize)) { + ddict->dictBuffer = NULL; + ddict->dictContent = dict; + } else { + void *const internalBuffer = ZSTD_malloc(dictSize, customMem); + if (!internalBuffer) { + ZSTD_freeDDict(ddict); + return NULL; + } + memcpy(internalBuffer, dict, dictSize); + ddict->dictBuffer = internalBuffer; + ddict->dictContent = internalBuffer; + } + ddict->dictSize = dictSize; + ddict->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + /* parse dictionary content */ + { + size_t const errorCode = ZSTD_loadEntropy_inDDict(ddict); + if (ZSTD_isError(errorCode)) { + ZSTD_freeDDict(ddict); + return NULL; + } + } + + return ddict; + } +} + +/*! ZSTD_initDDict() : +* Create a digested dictionary, to start decompression without startup delay. +* `dict` content is copied inside DDict. +* Consequently, `dict` can be released after `ZSTD_DDict` creation */ +ZSTD_DDict *ZSTD_initDDict(const void *dict, size_t dictSize, void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + return ZSTD_createDDict_advanced(dict, dictSize, 1, stackMem); +} + +size_t ZSTD_freeDDict(ZSTD_DDict *ddict) +{ + if (ddict == NULL) + return 0; /* support free on NULL */ + { + ZSTD_customMem const cMem = ddict->cMem; + ZSTD_free(ddict->dictBuffer, cMem); + ZSTD_free(ddict, cMem); + return 0; + } +} + +/*! ZSTD_getDictID_fromDict() : + * Provides the dictID stored within dictionary. + * if @return == 0, the dictionary is not conformant with Zstandard specification. + * It can still be loaded, but as a content-only dictionary. */ +unsigned ZSTD_getDictID_fromDict(const void *dict, size_t dictSize) +{ + if (dictSize < 8) + return 0; + if (ZSTD_readLE32(dict) != ZSTD_DICT_MAGIC) + return 0; + return ZSTD_readLE32((const char *)dict + 4); +} + +/*! ZSTD_getDictID_fromDDict() : + * Provides the dictID of the dictionary loaded into `ddict`. + * If @return == 0, the dictionary is not conformant to Zstandard specification, or empty. + * Non-conformant dictionaries can still be loaded, but as content-only dictionaries. */ +unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict *ddict) +{ + if (ddict == NULL) + return 0; + return ZSTD_getDictID_fromDict(ddict->dictContent, ddict->dictSize); +} + +/*! ZSTD_getDictID_fromFrame() : + * Provides the dictID required to decompressed the frame stored within `src`. + * If @return == 0, the dictID could not be decoded. + * This could for one of the following reasons : + * - The frame does not require a dictionary to be decoded (most common case). + * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden information. + * Note : this use case also happens when using a non-conformant dictionary. + * - `srcSize` is too small, and as a result, the frame header could not be decoded (only possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`). + * - This is not a Zstandard frame. + * When identifying the exact failure cause, it's possible to used ZSTD_getFrameParams(), which will provide a more precise error code. */ +unsigned ZSTD_getDictID_fromFrame(const void *src, size_t srcSize) +{ + ZSTD_frameParams zfp = {0, 0, 0, 0}; + size_t const hError = ZSTD_getFrameParams(&zfp, src, srcSize); + if (ZSTD_isError(hError)) + return 0; + return zfp.dictID; +} + +/*! ZSTD_decompress_usingDDict() : +* Decompression using a pre-digested Dictionary +* Use dictionary without significant overhead. */ +size_t ZSTD_decompress_usingDDict(ZSTD_DCtx *dctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, const ZSTD_DDict *ddict) +{ + /* pass content and size in case legacy frames are encountered */ + return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, NULL, 0, ddict); +} + +/*===================================== +* Streaming decompression +*====================================*/ + +typedef enum { zdss_init, zdss_loadHeader, zdss_read, zdss_load, zdss_flush } ZSTD_dStreamStage; + +/* *** Resource management *** */ +struct ZSTD_DStream_s { + ZSTD_DCtx *dctx; + ZSTD_DDict *ddictLocal; + const ZSTD_DDict *ddict; + ZSTD_frameParams fParams; + ZSTD_dStreamStage stage; + char *inBuff; + size_t inBuffSize; + size_t inPos; + size_t maxWindowSize; + char *outBuff; + size_t outBuffSize; + size_t outStart; + size_t outEnd; + size_t blockSize; + BYTE headerBuffer[ZSTD_FRAMEHEADERSIZE_MAX]; /* tmp buffer to store frame header */ + size_t lhSize; + ZSTD_customMem customMem; + void *legacyContext; + U32 previousLegacyVersion; + U32 legacyVersion; + U32 hostageByte; +}; /* typedef'd to ZSTD_DStream within "zstd.h" */ + +size_t ZSTD_DStreamWorkspaceBound(size_t maxWindowSize) +{ + size_t const blockSize = MIN(maxWindowSize, ZSTD_BLOCKSIZE_ABSOLUTEMAX); + size_t const inBuffSize = blockSize; + size_t const outBuffSize = maxWindowSize + blockSize + WILDCOPY_OVERLENGTH * 2; + return ZSTD_DCtxWorkspaceBound() + ZSTD_ALIGN(sizeof(ZSTD_DStream)) + ZSTD_ALIGN(inBuffSize) + ZSTD_ALIGN(outBuffSize); +} + +static ZSTD_DStream *ZSTD_createDStream_advanced(ZSTD_customMem customMem) +{ + ZSTD_DStream *zds; + + if (!customMem.customAlloc || !customMem.customFree) + return NULL; + + zds = (ZSTD_DStream *)ZSTD_malloc(sizeof(ZSTD_DStream), customMem); + if (zds == NULL) + return NULL; + memset(zds, 0, sizeof(ZSTD_DStream)); + memcpy(&zds->customMem, &customMem, sizeof(ZSTD_customMem)); + zds->dctx = ZSTD_createDCtx_advanced(customMem); + if (zds->dctx == NULL) { + ZSTD_freeDStream(zds); + return NULL; + } + zds->stage = zdss_init; + zds->maxWindowSize = ZSTD_MAXWINDOWSIZE_DEFAULT; + return zds; +} + +ZSTD_DStream *ZSTD_initDStream(size_t maxWindowSize, void *workspace, size_t workspaceSize) +{ + ZSTD_customMem const stackMem = ZSTD_initStack(workspace, workspaceSize); + ZSTD_DStream *zds = ZSTD_createDStream_advanced(stackMem); + if (!zds) { + return NULL; + } + + zds->maxWindowSize = maxWindowSize; + zds->stage = zdss_loadHeader; + zds->lhSize = zds->inPos = zds->outStart = zds->outEnd = 0; + ZSTD_freeDDict(zds->ddictLocal); + zds->ddictLocal = NULL; + zds->ddict = zds->ddictLocal; + zds->legacyVersion = 0; + zds->hostageByte = 0; + + { + size_t const blockSize = MIN(zds->maxWindowSize, ZSTD_BLOCKSIZE_ABSOLUTEMAX); + size_t const neededOutSize = zds->maxWindowSize + blockSize + WILDCOPY_OVERLENGTH * 2; + + zds->inBuff = (char *)ZSTD_malloc(blockSize, zds->customMem); + zds->inBuffSize = blockSize; + zds->outBuff = (char *)ZSTD_malloc(neededOutSize, zds->customMem); + zds->outBuffSize = neededOutSize; + if (zds->inBuff == NULL || zds->outBuff == NULL) { + ZSTD_freeDStream(zds); + return NULL; + } + } + return zds; +} + +ZSTD_DStream *ZSTD_initDStream_usingDDict(size_t maxWindowSize, const ZSTD_DDict *ddict, void *workspace, size_t workspaceSize) +{ + ZSTD_DStream *zds = ZSTD_initDStream(maxWindowSize, workspace, workspaceSize); + if (zds) { + zds->ddict = ddict; + } + return zds; +} + +size_t ZSTD_freeDStream(ZSTD_DStream *zds) +{ + if (zds == NULL) + return 0; /* support free on null */ + { + ZSTD_customMem const cMem = zds->customMem; + ZSTD_freeDCtx(zds->dctx); + zds->dctx = NULL; + ZSTD_freeDDict(zds->ddictLocal); + zds->ddictLocal = NULL; + ZSTD_free(zds->inBuff, cMem); + zds->inBuff = NULL; + ZSTD_free(zds->outBuff, cMem); + zds->outBuff = NULL; + ZSTD_free(zds, cMem); + return 0; + } +} + +/* *** Initialization *** */ + +size_t ZSTD_DStreamInSize(void) { return ZSTD_BLOCKSIZE_ABSOLUTEMAX + ZSTD_blockHeaderSize; } +size_t ZSTD_DStreamOutSize(void) { return ZSTD_BLOCKSIZE_ABSOLUTEMAX; } + +size_t ZSTD_resetDStream(ZSTD_DStream *zds) +{ + zds->stage = zdss_loadHeader; + zds->lhSize = zds->inPos = zds->outStart = zds->outEnd = 0; + zds->legacyVersion = 0; + zds->hostageByte = 0; + return ZSTD_frameHeaderSize_prefix; +} + +/* ***** Decompression ***** */ + +ZSTD_STATIC size_t ZSTD_limitCopy(void *dst, size_t dstCapacity, const void *src, size_t srcSize) +{ + size_t const length = MIN(dstCapacity, srcSize); + memcpy(dst, src, length); + return length; +} + +size_t ZSTD_decompressStream(ZSTD_DStream *zds, ZSTD_outBuffer *output, ZSTD_inBuffer *input) +{ + const char *const istart = (const char *)(input->src) + input->pos; + const char *const iend = (const char *)(input->src) + input->size; + const char *ip = istart; + char *const ostart = (char *)(output->dst) + output->pos; + char *const oend = (char *)(output->dst) + output->size; + char *op = ostart; + U32 someMoreWork = 1; + + while (someMoreWork) { + switch (zds->stage) { + case zdss_init: + ZSTD_resetDStream(zds); /* transparent reset on starting decoding a new frame */ + /* fall-through */ + + case zdss_loadHeader: { + size_t const hSize = ZSTD_getFrameParams(&zds->fParams, zds->headerBuffer, zds->lhSize); + if (ZSTD_isError(hSize)) + return hSize; + if (hSize != 0) { /* need more input */ + size_t const toLoad = hSize - zds->lhSize; /* if hSize!=0, hSize > zds->lhSize */ + if (toLoad > (size_t)(iend - ip)) { /* not enough input to load full header */ + memcpy(zds->headerBuffer + zds->lhSize, ip, iend - ip); + zds->lhSize += iend - ip; + input->pos = input->size; + return (MAX(ZSTD_frameHeaderSize_min, hSize) - zds->lhSize) + + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */ + } + memcpy(zds->headerBuffer + zds->lhSize, ip, toLoad); + zds->lhSize = hSize; + ip += toLoad; + break; + } + + /* check for single-pass mode opportunity */ + if (zds->fParams.frameContentSize && zds->fParams.windowSize /* skippable frame if == 0 */ + && (U64)(size_t)(oend - op) >= zds->fParams.frameContentSize) { + size_t const cSize = ZSTD_findFrameCompressedSize(istart, iend - istart); + if (cSize <= (size_t)(iend - istart)) { + size_t const decompressedSize = ZSTD_decompress_usingDDict(zds->dctx, op, oend - op, istart, cSize, zds->ddict); + if (ZSTD_isError(decompressedSize)) + return decompressedSize; + ip = istart + cSize; + op += decompressedSize; + zds->dctx->expected = 0; + zds->stage = zdss_init; + someMoreWork = 0; + break; + } + } + + /* Consume header */ + ZSTD_refDDict(zds->dctx, zds->ddict); + { + size_t const h1Size = ZSTD_nextSrcSizeToDecompress(zds->dctx); /* == ZSTD_frameHeaderSize_prefix */ + CHECK_F(ZSTD_decompressContinue(zds->dctx, NULL, 0, zds->headerBuffer, h1Size)); + { + size_t const h2Size = ZSTD_nextSrcSizeToDecompress(zds->dctx); + CHECK_F(ZSTD_decompressContinue(zds->dctx, NULL, 0, zds->headerBuffer + h1Size, h2Size)); + } + } + + zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN); + if (zds->fParams.windowSize > zds->maxWindowSize) + return ERROR(frameParameter_windowTooLarge); + + /* Buffers are preallocated, but double check */ + { + size_t const blockSize = MIN(zds->maxWindowSize, ZSTD_BLOCKSIZE_ABSOLUTEMAX); + size_t const neededOutSize = zds->maxWindowSize + blockSize + WILDCOPY_OVERLENGTH * 2; + if (zds->inBuffSize < blockSize) { + return ERROR(GENERIC); + } + if (zds->outBuffSize < neededOutSize) { + return ERROR(GENERIC); + } + zds->blockSize = blockSize; + } + zds->stage = zdss_read; + } + /* pass-through */ + + case zdss_read: { + size_t const neededInSize = ZSTD_nextSrcSizeToDecompress(zds->dctx); + if (neededInSize == 0) { /* end of frame */ + zds->stage = zdss_init; + someMoreWork = 0; + break; + } + if ((size_t)(iend - ip) >= neededInSize) { /* decode directly from src */ + const int isSkipFrame = ZSTD_isSkipFrame(zds->dctx); + size_t const decodedSize = ZSTD_decompressContinue(zds->dctx, zds->outBuff + zds->outStart, + (isSkipFrame ? 0 : zds->outBuffSize - zds->outStart), ip, neededInSize); + if (ZSTD_isError(decodedSize)) + return decodedSize; + ip += neededInSize; + if (!decodedSize && !isSkipFrame) + break; /* this was just a header */ + zds->outEnd = zds->outStart + decodedSize; + zds->stage = zdss_flush; + break; + } + if (ip == iend) { + someMoreWork = 0; + break; + } /* no more input */ + zds->stage = zdss_load; + /* pass-through */ + } + + case zdss_load: { + size_t const neededInSize = ZSTD_nextSrcSizeToDecompress(zds->dctx); + size_t const toLoad = neededInSize - zds->inPos; /* should always be <= remaining space within inBuff */ + size_t loadedSize; + if (toLoad > zds->inBuffSize - zds->inPos) + return ERROR(corruption_detected); /* should never happen */ + loadedSize = ZSTD_limitCopy(zds->inBuff + zds->inPos, toLoad, ip, iend - ip); + ip += loadedSize; + zds->inPos += loadedSize; + if (loadedSize < toLoad) { + someMoreWork = 0; + break; + } /* not enough input, wait for more */ + + /* decode loaded input */ + { + const int isSkipFrame = ZSTD_isSkipFrame(zds->dctx); + size_t const decodedSize = ZSTD_decompressContinue(zds->dctx, zds->outBuff + zds->outStart, zds->outBuffSize - zds->outStart, + zds->inBuff, neededInSize); + if (ZSTD_isError(decodedSize)) + return decodedSize; + zds->inPos = 0; /* input is consumed */ + if (!decodedSize && !isSkipFrame) { + zds->stage = zdss_read; + break; + } /* this was just a header */ + zds->outEnd = zds->outStart + decodedSize; + zds->stage = zdss_flush; + /* pass-through */ + } + } + + case zdss_flush: { + size_t const toFlushSize = zds->outEnd - zds->outStart; + size_t const flushedSize = ZSTD_limitCopy(op, oend - op, zds->outBuff + zds->outStart, toFlushSize); + op += flushedSize; + zds->outStart += flushedSize; + if (flushedSize == toFlushSize) { /* flush completed */ + zds->stage = zdss_read; + if (zds->outStart + zds->blockSize > zds->outBuffSize) + zds->outStart = zds->outEnd = 0; + break; + } + /* cannot complete flush */ + someMoreWork = 0; + break; + } + default: + return ERROR(GENERIC); /* impossible */ + } + } + + /* result */ + input->pos += (size_t)(ip - istart); + output->pos += (size_t)(op - ostart); + { + size_t nextSrcSizeHint = ZSTD_nextSrcSizeToDecompress(zds->dctx); + if (!nextSrcSizeHint) { /* frame fully decoded */ + if (zds->outEnd == zds->outStart) { /* output fully flushed */ + if (zds->hostageByte) { + if (input->pos >= input->size) { + zds->stage = zdss_read; + return 1; + } /* can't release hostage (not present) */ + input->pos++; /* release hostage */ + } + return 0; + } + if (!zds->hostageByte) { /* output not fully flushed; keep last byte as hostage; will be released when all output is flushed */ + input->pos--; /* note : pos > 0, otherwise, impossible to finish reading last block */ + zds->hostageByte = 1; + } + return 1; + } + nextSrcSizeHint += ZSTD_blockHeaderSize * (ZSTD_nextInputType(zds->dctx) == ZSTDnit_block); /* preload header of next block */ + if (zds->inPos > nextSrcSizeHint) + return ERROR(GENERIC); /* should never happen */ + nextSrcSizeHint -= zds->inPos; /* already loaded*/ + return nextSrcSizeHint; + } +} + +EXPORT_SYMBOL(ZSTD_DCtxWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initDCtx); +EXPORT_SYMBOL(ZSTD_decompressDCtx); +EXPORT_SYMBOL(ZSTD_decompress_usingDict); + +EXPORT_SYMBOL(ZSTD_DDictWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initDDict); +EXPORT_SYMBOL(ZSTD_decompress_usingDDict); + +EXPORT_SYMBOL(ZSTD_DStreamWorkspaceBound); +EXPORT_SYMBOL(ZSTD_initDStream); +EXPORT_SYMBOL(ZSTD_initDStream_usingDDict); +EXPORT_SYMBOL(ZSTD_resetDStream); +EXPORT_SYMBOL(ZSTD_decompressStream); +EXPORT_SYMBOL(ZSTD_DStreamInSize); +EXPORT_SYMBOL(ZSTD_DStreamOutSize); + +EXPORT_SYMBOL(ZSTD_findFrameCompressedSize); +EXPORT_SYMBOL(ZSTD_getFrameContentSize); +EXPORT_SYMBOL(ZSTD_findDecompressedSize); + +EXPORT_SYMBOL(ZSTD_isFrame); +EXPORT_SYMBOL(ZSTD_getDictID_fromDict); +EXPORT_SYMBOL(ZSTD_getDictID_fromDDict); +EXPORT_SYMBOL(ZSTD_getDictID_fromFrame); + +EXPORT_SYMBOL(ZSTD_getFrameParams); +EXPORT_SYMBOL(ZSTD_decompressBegin); +EXPORT_SYMBOL(ZSTD_decompressBegin_usingDict); +EXPORT_SYMBOL(ZSTD_copyDCtx); +EXPORT_SYMBOL(ZSTD_nextSrcSizeToDecompress); +EXPORT_SYMBOL(ZSTD_decompressContinue); +EXPORT_SYMBOL(ZSTD_nextInputType); + +EXPORT_SYMBOL(ZSTD_decompressBlock); +EXPORT_SYMBOL(ZSTD_insertBlock); + +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("Zstd Decompressor"); diff --git a/lib/zstd/entropy_common.c b/lib/zstd/entropy_common.c new file mode 100644 index 000000000000..2b0a643c32c4 --- /dev/null +++ b/lib/zstd/entropy_common.c @@ -0,0 +1,243 @@ +/* + * Common functions of New Generation Entropy library + * Copyright (C) 2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ + +/* ************************************* +* Dependencies +***************************************/ +#include "error_private.h" /* ERR_*, ERROR */ +#include "fse.h" +#include "huf.h" +#include "mem.h" + +/*=== Version ===*/ +unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; } + +/*=== Error Management ===*/ +unsigned FSE_isError(size_t code) { return ERR_isError(code); } + +unsigned HUF_isError(size_t code) { return ERR_isError(code); } + +/*-************************************************************** +* FSE NCount encoding-decoding +****************************************************************/ +size_t FSE_readNCount(short *normalizedCounter, unsigned *maxSVPtr, unsigned *tableLogPtr, const void *headerBuffer, size_t hbSize) +{ + const BYTE *const istart = (const BYTE *)headerBuffer; + const BYTE *const iend = istart + hbSize; + const BYTE *ip = istart; + int nbBits; + int remaining; + int threshold; + U32 bitStream; + int bitCount; + unsigned charnum = 0; + int previous0 = 0; + + if (hbSize < 4) + return ERROR(srcSize_wrong); + bitStream = ZSTD_readLE32(ip); + nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */ + if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) + return ERROR(tableLog_tooLarge); + bitStream >>= 4; + bitCount = 4; + *tableLogPtr = nbBits; + remaining = (1 << nbBits) + 1; + threshold = 1 << nbBits; + nbBits++; + + while ((remaining > 1) & (charnum <= *maxSVPtr)) { + if (previous0) { + unsigned n0 = charnum; + while ((bitStream & 0xFFFF) == 0xFFFF) { + n0 += 24; + if (ip < iend - 5) { + ip += 2; + bitStream = ZSTD_readLE32(ip) >> bitCount; + } else { + bitStream >>= 16; + bitCount += 16; + } + } + while ((bitStream & 3) == 3) { + n0 += 3; + bitStream >>= 2; + bitCount += 2; + } + n0 += bitStream & 3; + bitCount += 2; + if (n0 > *maxSVPtr) + return ERROR(maxSymbolValue_tooSmall); + while (charnum < n0) + normalizedCounter[charnum++] = 0; + if ((ip <= iend - 7) || (ip + (bitCount >> 3) <= iend - 4)) { + ip += bitCount >> 3; + bitCount &= 7; + bitStream = ZSTD_readLE32(ip) >> bitCount; + } else { + bitStream >>= 2; + } + } + { + int const max = (2 * threshold - 1) - remaining; + int count; + + if ((bitStream & (threshold - 1)) < (U32)max) { + count = bitStream & (threshold - 1); + bitCount += nbBits - 1; + } else { + count = bitStream & (2 * threshold - 1); + if (count >= threshold) + count -= max; + bitCount += nbBits; + } + + count--; /* extra accuracy */ + remaining -= count < 0 ? -count : count; /* -1 means +1 */ + normalizedCounter[charnum++] = (short)count; + previous0 = !count; + while (remaining < threshold) { + nbBits--; + threshold >>= 1; + } + + if ((ip <= iend - 7) || (ip + (bitCount >> 3) <= iend - 4)) { + ip += bitCount >> 3; + bitCount &= 7; + } else { + bitCount -= (int)(8 * (iend - 4 - ip)); + ip = iend - 4; + } + bitStream = ZSTD_readLE32(ip) >> (bitCount & 31); + } + } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */ + if (remaining != 1) + return ERROR(corruption_detected); + if (bitCount > 32) + return ERROR(corruption_detected); + *maxSVPtr = charnum - 1; + + ip += (bitCount + 7) >> 3; + return ip - istart; +} + +/*! HUF_readStats() : + Read compact Huffman tree, saved by HUF_writeCTable(). + `huffWeight` is destination buffer. + `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32. + @return : size read from `src` , or an error Code . + Note : Needed by HUF_readCTable() and HUF_readDTableX?() . +*/ +size_t HUF_readStats_wksp(BYTE *huffWeight, size_t hwSize, U32 *rankStats, U32 *nbSymbolsPtr, U32 *tableLogPtr, const void *src, size_t srcSize, void *workspace, size_t workspaceSize) +{ + U32 weightTotal; + const BYTE *ip = (const BYTE *)src; + size_t iSize; + size_t oSize; + + if (!srcSize) + return ERROR(srcSize_wrong); + iSize = ip[0]; + /* memset(huffWeight, 0, hwSize); */ /* is not necessary, even though some analyzer complain ... */ + + if (iSize >= 128) { /* special header */ + oSize = iSize - 127; + iSize = ((oSize + 1) / 2); + if (iSize + 1 > srcSize) + return ERROR(srcSize_wrong); + if (oSize >= hwSize) + return ERROR(corruption_detected); + ip += 1; + { + U32 n; + for (n = 0; n < oSize; n += 2) { + huffWeight[n] = ip[n / 2] >> 4; + huffWeight[n + 1] = ip[n / 2] & 15; + } + } + } else { /* header compressed with FSE (normal case) */ + if (iSize + 1 > srcSize) + return ERROR(srcSize_wrong); + oSize = FSE_decompress_wksp(huffWeight, hwSize - 1, ip + 1, iSize, 6, workspace, workspaceSize); /* max (hwSize-1) values decoded, as last one is implied */ + if (FSE_isError(oSize)) + return oSize; + } + + /* collect weight stats */ + memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32)); + weightTotal = 0; + { + U32 n; + for (n = 0; n < oSize; n++) { + if (huffWeight[n] >= HUF_TABLELOG_MAX) + return ERROR(corruption_detected); + rankStats[huffWeight[n]]++; + weightTotal += (1 << huffWeight[n]) >> 1; + } + } + if (weightTotal == 0) + return ERROR(corruption_detected); + + /* get last non-null symbol weight (implied, total must be 2^n) */ + { + U32 const tableLog = BIT_highbit32(weightTotal) + 1; + if (tableLog > HUF_TABLELOG_MAX) + return ERROR(corruption_detected); + *tableLogPtr = tableLog; + /* determine last weight */ + { + U32 const total = 1 << tableLog; + U32 const rest = total - weightTotal; + U32 const verif = 1 << BIT_highbit32(rest); + U32 const lastWeight = BIT_highbit32(rest) + 1; + if (verif != rest) + return ERROR(corruption_detected); /* last value must be a clean power of 2 */ + huffWeight[oSize] = (BYTE)lastWeight; + rankStats[lastWeight]++; + } + } + + /* check tree construction validity */ + if ((rankStats[1] < 2) || (rankStats[1] & 1)) + return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */ + + /* results */ + *nbSymbolsPtr = (U32)(oSize + 1); + return iSize + 1; +} diff --git a/lib/zstd/error_private.h b/lib/zstd/error_private.h new file mode 100644 index 000000000000..1a60b31f706c --- /dev/null +++ b/lib/zstd/error_private.h @@ -0,0 +1,53 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +/* Note : this module is expected to remain private, do not expose it */ + +#ifndef ERROR_H_MODULE +#define ERROR_H_MODULE + +/* **************************************** +* Dependencies +******************************************/ +#include /* size_t */ +#include /* enum list */ + +/* **************************************** +* Compiler-specific +******************************************/ +#define ERR_STATIC static __attribute__((unused)) + +/*-**************************************** +* Customization (error_public.h) +******************************************/ +typedef ZSTD_ErrorCode ERR_enum; +#define PREFIX(name) ZSTD_error_##name + +/*-**************************************** +* Error codes handling +******************************************/ +#define ERROR(name) ((size_t)-PREFIX(name)) + +ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); } + +ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) +{ + if (!ERR_isError(code)) + return (ERR_enum)0; + return (ERR_enum)(0 - code); +} + +#endif /* ERROR_H_MODULE */ diff --git a/lib/zstd/fse.h b/lib/zstd/fse.h new file mode 100644 index 000000000000..7460ab04b191 --- /dev/null +++ b/lib/zstd/fse.h @@ -0,0 +1,575 @@ +/* + * FSE : Finite State Entropy codec + * Public Prototypes declaration + * Copyright (C) 2013-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ +#ifndef FSE_H +#define FSE_H + +/*-***************************************** +* Dependencies +******************************************/ +#include /* size_t, ptrdiff_t */ + +/*-***************************************** +* FSE_PUBLIC_API : control library symbols visibility +******************************************/ +#define FSE_PUBLIC_API + +/*------ Version ------*/ +#define FSE_VERSION_MAJOR 0 +#define FSE_VERSION_MINOR 9 +#define FSE_VERSION_RELEASE 0 + +#define FSE_LIB_VERSION FSE_VERSION_MAJOR.FSE_VERSION_MINOR.FSE_VERSION_RELEASE +#define FSE_QUOTE(str) #str +#define FSE_EXPAND_AND_QUOTE(str) FSE_QUOTE(str) +#define FSE_VERSION_STRING FSE_EXPAND_AND_QUOTE(FSE_LIB_VERSION) + +#define FSE_VERSION_NUMBER (FSE_VERSION_MAJOR * 100 * 100 + FSE_VERSION_MINOR * 100 + FSE_VERSION_RELEASE) +FSE_PUBLIC_API unsigned FSE_versionNumber(void); /**< library version number; to be used when checking dll version */ + +/*-***************************************** +* Tool functions +******************************************/ +FSE_PUBLIC_API size_t FSE_compressBound(size_t size); /* maximum compressed size */ + +/* Error Management */ +FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return value is an error code */ + +/*-***************************************** +* FSE detailed API +******************************************/ +/*! +FSE_compress() does the following: +1. count symbol occurrence from source[] into table count[] +2. normalize counters so that sum(count[]) == Power_of_2 (2^tableLog) +3. save normalized counters to memory buffer using writeNCount() +4. build encoding table 'CTable' from normalized counters +5. encode the data stream using encoding table 'CTable' + +FSE_decompress() does the following: +1. read normalized counters with readNCount() +2. build decoding table 'DTable' from normalized counters +3. decode the data stream using decoding table 'DTable' + +The following API allows targeting specific sub-functions for advanced tasks. +For example, it's possible to compress several blocks using the same 'CTable', +or to save and provide normalized distribution using external method. +*/ + +/* *** COMPRESSION *** */ +/*! FSE_optimalTableLog(): + dynamically downsize 'tableLog' when conditions are met. + It saves CPU time, by using smaller tables, while preserving or even improving compression ratio. + @return : recommended tableLog (necessarily <= 'maxTableLog') */ +FSE_PUBLIC_API unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue); + +/*! FSE_normalizeCount(): + normalize counts so that sum(count[]) == Power_of_2 (2^tableLog) + 'normalizedCounter' is a table of short, of minimum size (maxSymbolValue+1). + @return : tableLog, + or an errorCode, which can be tested using FSE_isError() */ +FSE_PUBLIC_API size_t FSE_normalizeCount(short *normalizedCounter, unsigned tableLog, const unsigned *count, size_t srcSize, unsigned maxSymbolValue); + +/*! FSE_NCountWriteBound(): + Provides the maximum possible size of an FSE normalized table, given 'maxSymbolValue' and 'tableLog'. + Typically useful for allocation purpose. */ +FSE_PUBLIC_API size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog); + +/*! FSE_writeNCount(): + Compactly save 'normalizedCounter' into 'buffer'. + @return : size of the compressed table, + or an errorCode, which can be tested using FSE_isError(). */ +FSE_PUBLIC_API size_t FSE_writeNCount(void *buffer, size_t bufferSize, const short *normalizedCounter, unsigned maxSymbolValue, unsigned tableLog); + +/*! Constructor and Destructor of FSE_CTable. + Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */ +typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */ + +/*! FSE_compress_usingCTable(): + Compress `src` using `ct` into `dst` which must be already allocated. + @return : size of compressed data (<= `dstCapacity`), + or 0 if compressed data could not fit into `dst`, + or an errorCode, which can be tested using FSE_isError() */ +FSE_PUBLIC_API size_t FSE_compress_usingCTable(void *dst, size_t dstCapacity, const void *src, size_t srcSize, const FSE_CTable *ct); + +/*! +Tutorial : +---------- +The first step is to count all symbols. FSE_count() does this job very fast. +Result will be saved into 'count', a table of unsigned int, which must be already allocated, and have 'maxSymbolValuePtr[0]+1' cells. +'src' is a table of bytes of size 'srcSize'. All values within 'src' MUST be <= maxSymbolValuePtr[0] +maxSymbolValuePtr[0] will be updated, with its real value (necessarily <= original value) +FSE_count() will return the number of occurrence of the most frequent symbol. +This can be used to know if there is a single symbol within 'src', and to quickly evaluate its compressibility. +If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()). + +The next step is to normalize the frequencies. +FSE_normalizeCount() will ensure that sum of frequencies is == 2 ^'tableLog'. +It also guarantees a minimum of 1 to any Symbol with frequency >= 1. +You can use 'tableLog'==0 to mean "use default tableLog value". +If you are unsure of which tableLog value to use, you can ask FSE_optimalTableLog(), +which will provide the optimal valid tableLog given sourceSize, maxSymbolValue, and a user-defined maximum (0 means "default"). + +The result of FSE_normalizeCount() will be saved into a table, +called 'normalizedCounter', which is a table of signed short. +'normalizedCounter' must be already allocated, and have at least 'maxSymbolValue+1' cells. +The return value is tableLog if everything proceeded as expected. +It is 0 if there is a single symbol within distribution. +If there is an error (ex: invalid tableLog value), the function will return an ErrorCode (which can be tested using FSE_isError()). + +'normalizedCounter' can be saved in a compact manner to a memory area using FSE_writeNCount(). +'buffer' must be already allocated. +For guaranteed success, buffer size must be at least FSE_headerBound(). +The result of the function is the number of bytes written into 'buffer'. +If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError(); ex : buffer size too small). + +'normalizedCounter' can then be used to create the compression table 'CTable'. +The space required by 'CTable' must be already allocated, using FSE_createCTable(). +You can then use FSE_buildCTable() to fill 'CTable'. +If there is an error, both functions will return an ErrorCode (which can be tested using FSE_isError()). + +'CTable' can then be used to compress 'src', with FSE_compress_usingCTable(). +Similar to FSE_count(), the convention is that 'src' is assumed to be a table of char of size 'srcSize' +The function returns the size of compressed data (without header), necessarily <= `dstCapacity`. +If it returns '0', compressed data could not fit into 'dst'. +If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()). +*/ + +/* *** DECOMPRESSION *** */ + +/*! FSE_readNCount(): + Read compactly saved 'normalizedCounter' from 'rBuffer'. + @return : size read from 'rBuffer', + or an errorCode, which can be tested using FSE_isError(). + maxSymbolValuePtr[0] and tableLogPtr[0] will also be updated with their respective values */ +FSE_PUBLIC_API size_t FSE_readNCount(short *normalizedCounter, unsigned *maxSymbolValuePtr, unsigned *tableLogPtr, const void *rBuffer, size_t rBuffSize); + +/*! Constructor and Destructor of FSE_DTable. + Note that its size depends on 'tableLog' */ +typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */ + +/*! FSE_buildDTable(): + Builds 'dt', which must be already allocated, using FSE_createDTable(). + return : 0, or an errorCode, which can be tested using FSE_isError() */ +FSE_PUBLIC_API size_t FSE_buildDTable_wksp(FSE_DTable *dt, const short *normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void *workspace, size_t workspaceSize); + +/*! FSE_decompress_usingDTable(): + Decompress compressed source `cSrc` of size `cSrcSize` using `dt` + into `dst` which must be already allocated. + @return : size of regenerated data (necessarily <= `dstCapacity`), + or an errorCode, which can be tested using FSE_isError() */ +FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void *dst, size_t dstCapacity, const void *cSrc, size_t cSrcSize, const FSE_DTable *dt); + +/*! +Tutorial : +---------- +(Note : these functions only decompress FSE-compressed blocks. + If block is uncompressed, use memcpy() instead + If block is a single repeated byte, use memset() instead ) + +The first step is to obtain the normalized frequencies of symbols. +This can be performed by FSE_readNCount() if it was saved using FSE_writeNCount(). +'normalizedCounter' must be already allocated, and have at least 'maxSymbolValuePtr[0]+1' cells of signed short. +In practice, that means it's necessary to know 'maxSymbolValue' beforehand, +or size the table to handle worst case situations (typically 256). +FSE_readNCount() will provide 'tableLog' and 'maxSymbolValue'. +The result of FSE_readNCount() is the number of bytes read from 'rBuffer'. +Note that 'rBufferSize' must be at least 4 bytes, even if useful information is less than that. +If there is an error, the function will return an error code, which can be tested using FSE_isError(). + +The next step is to build the decompression tables 'FSE_DTable' from 'normalizedCounter'. +This is performed by the function FSE_buildDTable(). +The space required by 'FSE_DTable' must be already allocated using FSE_createDTable(). +If there is an error, the function will return an error code, which can be tested using FSE_isError(). + +`FSE_DTable` can then be used to decompress `cSrc`, with FSE_decompress_usingDTable(). +`cSrcSize` must be strictly correct, otherwise decompression will fail. +FSE_decompress_usingDTable() result will tell how many bytes were regenerated (<=`dstCapacity`). +If there is an error, the function will return an error code, which can be tested using FSE_isError(). (ex: dst buffer too small) +*/ + +/* *** Dependency *** */ +#include "bitstream.h" + +/* ***************************************** +* Static allocation +*******************************************/ +/* FSE buffer bounds */ +#define FSE_NCOUNTBOUND 512 +#define FSE_BLOCKBOUND(size) (size + (size >> 7)) +#define FSE_COMPRESSBOUND(size) (FSE_NCOUNTBOUND + FSE_BLOCKBOUND(size)) /* Macro version, useful for static allocation */ + +/* It is possible to statically allocate FSE CTable/DTable as a table of FSE_CTable/FSE_DTable using below macros */ +#define FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) (1 + (1 << (maxTableLog - 1)) + ((maxSymbolValue + 1) * 2)) +#define FSE_DTABLE_SIZE_U32(maxTableLog) (1 + (1 << maxTableLog)) + +/* ***************************************** +* FSE advanced API +*******************************************/ +/* FSE_count_wksp() : + * Same as FSE_count(), but using an externally provided scratch buffer. + * `workSpace` size must be table of >= `1024` unsigned + */ +size_t FSE_count_wksp(unsigned *count, unsigned *maxSymbolValuePtr, const void *source, size_t sourceSize, unsigned *workSpace); + +/* FSE_countFast_wksp() : + * Same as FSE_countFast(), but using an externally provided scratch buffer. + * `workSpace` must be a table of minimum `1024` unsigned + */ +size_t FSE_countFast_wksp(unsigned *count, unsigned *maxSymbolValuePtr, const void *src, size_t srcSize, unsigned *workSpace); + +/*! FSE_count_simple + * Same as FSE_countFast(), but does not use any additional memory (not even on stack). + * This function is unsafe, and will segfault if any value within `src` is `> *maxSymbolValuePtr` (presuming it's also the size of `count`). +*/ +size_t FSE_count_simple(unsigned *count, unsigned *maxSymbolValuePtr, const void *src, size_t srcSize); + +unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus); +/**< same as FSE_optimalTableLog(), which used `minus==2` */ + +size_t FSE_buildCTable_raw(FSE_CTable *ct, unsigned nbBits); +/**< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */ + +size_t FSE_buildCTable_rle(FSE_CTable *ct, unsigned char symbolValue); +/**< build a fake FSE_CTable, designed to compress always the same symbolValue */ + +/* FSE_buildCTable_wksp() : + * Same as FSE_buildCTable(), but using an externally allocated scratch buffer (`workSpace`). + * `wkspSize` must be >= `(1<= BIT_DStream_completed + +When it's done, verify decompression is fully completed, by checking both DStream and the relevant states. +Checking if DStream has reached its end is performed by : + BIT_endOfDStream(&DStream); +Check also the states. There might be some symbols left there, if some high probability ones (>50%) are possible. + FSE_endOfDState(&DState); +*/ + +/* ***************************************** +* FSE unsafe API +*******************************************/ +static unsigned char FSE_decodeSymbolFast(FSE_DState_t *DStatePtr, BIT_DStream_t *bitD); +/* faster, but works only if nbBits is always >= 1 (otherwise, result will be corrupted) */ + +/* ***************************************** +* Implementation of inlined functions +*******************************************/ +typedef struct { + int deltaFindState; + U32 deltaNbBits; +} FSE_symbolCompressionTransform; /* total 8 bytes */ + +ZSTD_STATIC void FSE_initCState(FSE_CState_t *statePtr, const FSE_CTable *ct) +{ + const void *ptr = ct; + const U16 *u16ptr = (const U16 *)ptr; + const U32 tableLog = ZSTD_read16(ptr); + statePtr->value = (ptrdiff_t)1 << tableLog; + statePtr->stateTable = u16ptr + 2; + statePtr->symbolTT = ((const U32 *)ct + 1 + (tableLog ? (1 << (tableLog - 1)) : 1)); + statePtr->stateLog = tableLog; +} + +/*! FSE_initCState2() : +* Same as FSE_initCState(), but the first symbol to include (which will be the last to be read) +* uses the smallest state value possible, saving the cost of this symbol */ +ZSTD_STATIC void FSE_initCState2(FSE_CState_t *statePtr, const FSE_CTable *ct, U32 symbol) +{ + FSE_initCState(statePtr, ct); + { + const FSE_symbolCompressionTransform symbolTT = ((const FSE_symbolCompressionTransform *)(statePtr->symbolTT))[symbol]; + const U16 *stateTable = (const U16 *)(statePtr->stateTable); + U32 nbBitsOut = (U32)((symbolTT.deltaNbBits + (1 << 15)) >> 16); + statePtr->value = (nbBitsOut << 16) - symbolTT.deltaNbBits; + statePtr->value = stateTable[(statePtr->value >> nbBitsOut) + symbolTT.deltaFindState]; + } +} + +ZSTD_STATIC void FSE_encodeSymbol(BIT_CStream_t *bitC, FSE_CState_t *statePtr, U32 symbol) +{ + const FSE_symbolCompressionTransform symbolTT = ((const FSE_symbolCompressionTransform *)(statePtr->symbolTT))[symbol]; + const U16 *const stateTable = (const U16 *)(statePtr->stateTable); + U32 nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16); + BIT_addBits(bitC, statePtr->value, nbBitsOut); + statePtr->value = stateTable[(statePtr->value >> nbBitsOut) + symbolTT.deltaFindState]; +} + +ZSTD_STATIC void FSE_flushCState(BIT_CStream_t *bitC, const FSE_CState_t *statePtr) +{ + BIT_addBits(bitC, statePtr->value, statePtr->stateLog); + BIT_flushBits(bitC); +} + +/* ====== Decompression ====== */ + +typedef struct { + U16 tableLog; + U16 fastMode; +} FSE_DTableHeader; /* sizeof U32 */ + +typedef struct { + unsigned short newState; + unsigned char symbol; + unsigned char nbBits; +} FSE_decode_t; /* size == U32 */ + +ZSTD_STATIC void FSE_initDState(FSE_DState_t *DStatePtr, BIT_DStream_t *bitD, const FSE_DTable *dt) +{ + const void *ptr = dt; + const FSE_DTableHeader *const DTableH = (const FSE_DTableHeader *)ptr; + DStatePtr->state = BIT_readBits(bitD, DTableH->tableLog); + BIT_reloadDStream(bitD); + DStatePtr->table = dt + 1; +} + +ZSTD_STATIC BYTE FSE_peekSymbol(const FSE_DState_t *DStatePtr) +{ + FSE_decode_t const DInfo = ((const FSE_decode_t *)(DStatePtr->table))[DStatePtr->state]; + return DInfo.symbol; +} + +ZSTD_STATIC void FSE_updateState(FSE_DState_t *DStatePtr, BIT_DStream_t *bitD) +{ + FSE_decode_t const DInfo = ((const FSE_decode_t *)(DStatePtr->table))[DStatePtr->state]; + U32 const nbBits = DInfo.nbBits; + size_t const lowBits = BIT_readBits(bitD, nbBits); + DStatePtr->state = DInfo.newState + lowBits; +} + +ZSTD_STATIC BYTE FSE_decodeSymbol(FSE_DState_t *DStatePtr, BIT_DStream_t *bitD) +{ + FSE_decode_t const DInfo = ((const FSE_decode_t *)(DStatePtr->table))[DStatePtr->state]; + U32 const nbBits = DInfo.nbBits; + BYTE const symbol = DInfo.symbol; + size_t const lowBits = BIT_readBits(bitD, nbBits); + + DStatePtr->state = DInfo.newState + lowBits; + return symbol; +} + +/*! FSE_decodeSymbolFast() : + unsafe, only works if no symbol has a probability > 50% */ +ZSTD_STATIC BYTE FSE_decodeSymbolFast(FSE_DState_t *DStatePtr, BIT_DStream_t *bitD) +{ + FSE_decode_t const DInfo = ((const FSE_decode_t *)(DStatePtr->table))[DStatePtr->state]; + U32 const nbBits = DInfo.nbBits; + BYTE const symbol = DInfo.symbol; + size_t const lowBits = BIT_readBitsFast(bitD, nbBits); + + DStatePtr->state = DInfo.newState + lowBits; + return symbol; +} + +ZSTD_STATIC unsigned FSE_endOfDState(const FSE_DState_t *DStatePtr) { return DStatePtr->state == 0; } + +/* ************************************************************** +* Tuning parameters +****************************************************************/ +/*!MEMORY_USAGE : +* Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) +* Increasing memory usage improves compression ratio +* Reduced memory usage can improve speed, due to cache effect +* Recommended max value is 14, for 16KB, which nicely fits into Intel x86 L1 cache */ +#ifndef FSE_MAX_MEMORY_USAGE +#define FSE_MAX_MEMORY_USAGE 14 +#endif +#ifndef FSE_DEFAULT_MEMORY_USAGE +#define FSE_DEFAULT_MEMORY_USAGE 13 +#endif + +/*!FSE_MAX_SYMBOL_VALUE : +* Maximum symbol value authorized. +* Required for proper stack allocation */ +#ifndef FSE_MAX_SYMBOL_VALUE +#define FSE_MAX_SYMBOL_VALUE 255 +#endif + +/* ************************************************************** +* template functions type & suffix +****************************************************************/ +#define FSE_FUNCTION_TYPE BYTE +#define FSE_FUNCTION_EXTENSION +#define FSE_DECODE_TYPE FSE_decode_t + +/* *************************************************************** +* Constants +*****************************************************************/ +#define FSE_MAX_TABLELOG (FSE_MAX_MEMORY_USAGE - 2) +#define FSE_MAX_TABLESIZE (1U << FSE_MAX_TABLELOG) +#define FSE_MAXTABLESIZE_MASK (FSE_MAX_TABLESIZE - 1) +#define FSE_DEFAULT_TABLELOG (FSE_DEFAULT_MEMORY_USAGE - 2) +#define FSE_MIN_TABLELOG 5 + +#define FSE_TABLELOG_ABSOLUTE_MAX 15 +#if FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX +#error "FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX is not supported" +#endif + +#define FSE_TABLESTEP(tableSize) ((tableSize >> 1) + (tableSize >> 3) + 3) + +#endif /* FSE_H */ diff --git a/lib/zstd/fse_compress.c b/lib/zstd/fse_compress.c new file mode 100644 index 000000000000..ef3d1741d532 --- /dev/null +++ b/lib/zstd/fse_compress.c @@ -0,0 +1,795 @@ +/* + * FSE : Finite State Entropy encoder + * Copyright (C) 2013-2015, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ + +/* ************************************************************** +* Compiler specifics +****************************************************************/ +#define FORCE_INLINE static __always_inline + +/* ************************************************************** +* Includes +****************************************************************/ +#include "bitstream.h" +#include "fse.h" +#include +#include +#include +#include /* memcpy, memset */ + +/* ************************************************************** +* Error Management +****************************************************************/ +#define FSE_STATIC_ASSERT(c) \ + { \ + enum { FSE_static_assert = 1 / (int)(!!(c)) }; \ + } /* use only *after* variable declarations */ + +/* ************************************************************** +* Templates +****************************************************************/ +/* + designed to be included + for type-specific functions (template emulation in C) + Objective is to write these functions only once, for improved maintenance +*/ + +/* safety checks */ +#ifndef FSE_FUNCTION_EXTENSION +#error "FSE_FUNCTION_EXTENSION must be defined" +#endif +#ifndef FSE_FUNCTION_TYPE +#error "FSE_FUNCTION_TYPE must be defined" +#endif + +/* Function names */ +#define FSE_CAT(X, Y) X##Y +#define FSE_FUNCTION_NAME(X, Y) FSE_CAT(X, Y) +#define FSE_TYPE_NAME(X, Y) FSE_CAT(X, Y) + +/* Function templates */ + +/* FSE_buildCTable_wksp() : + * Same as FSE_buildCTable(), but using an externally allocated scratch buffer (`workSpace`). + * wkspSize should be sized to handle worst case situation, which is `1<> 1 : 1); + FSE_symbolCompressionTransform *const symbolTT = (FSE_symbolCompressionTransform *)(FSCT); + U32 const step = FSE_TABLESTEP(tableSize); + U32 highThreshold = tableSize - 1; + + U32 *cumul; + FSE_FUNCTION_TYPE *tableSymbol; + size_t spaceUsed32 = 0; + + cumul = (U32 *)workspace + spaceUsed32; + spaceUsed32 += FSE_MAX_SYMBOL_VALUE + 2; + tableSymbol = (FSE_FUNCTION_TYPE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(sizeof(FSE_FUNCTION_TYPE) * ((size_t)1 << tableLog), sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + /* CTable header */ + tableU16[-2] = (U16)tableLog; + tableU16[-1] = (U16)maxSymbolValue; + + /* For explanations on how to distribute symbol values over the table : + * http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ + + /* symbol start positions */ + { + U32 u; + cumul[0] = 0; + for (u = 1; u <= maxSymbolValue + 1; u++) { + if (normalizedCounter[u - 1] == -1) { /* Low proba symbol */ + cumul[u] = cumul[u - 1] + 1; + tableSymbol[highThreshold--] = (FSE_FUNCTION_TYPE)(u - 1); + } else { + cumul[u] = cumul[u - 1] + normalizedCounter[u - 1]; + } + } + cumul[maxSymbolValue + 1] = tableSize + 1; + } + + /* Spread symbols */ + { + U32 position = 0; + U32 symbol; + for (symbol = 0; symbol <= maxSymbolValue; symbol++) { + int nbOccurences; + for (nbOccurences = 0; nbOccurences < normalizedCounter[symbol]; nbOccurences++) { + tableSymbol[position] = (FSE_FUNCTION_TYPE)symbol; + position = (position + step) & tableMask; + while (position > highThreshold) + position = (position + step) & tableMask; /* Low proba area */ + } + } + + if (position != 0) + return ERROR(GENERIC); /* Must have gone through all positions */ + } + + /* Build table */ + { + U32 u; + for (u = 0; u < tableSize; u++) { + FSE_FUNCTION_TYPE s = tableSymbol[u]; /* note : static analyzer may not understand tableSymbol is properly initialized */ + tableU16[cumul[s]++] = (U16)(tableSize + u); /* TableU16 : sorted by symbol order; gives next state value */ + } + } + + /* Build Symbol Transformation Table */ + { + unsigned total = 0; + unsigned s; + for (s = 0; s <= maxSymbolValue; s++) { + switch (normalizedCounter[s]) { + case 0: break; + + case -1: + case 1: + symbolTT[s].deltaNbBits = (tableLog << 16) - (1 << tableLog); + symbolTT[s].deltaFindState = total - 1; + total++; + break; + default: { + U32 const maxBitsOut = tableLog - BIT_highbit32(normalizedCounter[s] - 1); + U32 const minStatePlus = normalizedCounter[s] << maxBitsOut; + symbolTT[s].deltaNbBits = (maxBitsOut << 16) - minStatePlus; + symbolTT[s].deltaFindState = total - normalizedCounter[s]; + total += normalizedCounter[s]; + } + } + } + } + + return 0; +} + +/*-************************************************************** +* FSE NCount encoding-decoding +****************************************************************/ +size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog) +{ + size_t const maxHeaderSize = (((maxSymbolValue + 1) * tableLog) >> 3) + 3; + return maxSymbolValue ? maxHeaderSize : FSE_NCOUNTBOUND; /* maxSymbolValue==0 ? use default */ +} + +static size_t FSE_writeNCount_generic(void *header, size_t headerBufferSize, const short *normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, + unsigned writeIsSafe) +{ + BYTE *const ostart = (BYTE *)header; + BYTE *out = ostart; + BYTE *const oend = ostart + headerBufferSize; + int nbBits; + const int tableSize = 1 << tableLog; + int remaining; + int threshold; + U32 bitStream; + int bitCount; + unsigned charnum = 0; + int previous0 = 0; + + bitStream = 0; + bitCount = 0; + /* Table Size */ + bitStream += (tableLog - FSE_MIN_TABLELOG) << bitCount; + bitCount += 4; + + /* Init */ + remaining = tableSize + 1; /* +1 for extra accuracy */ + threshold = tableSize; + nbBits = tableLog + 1; + + while (remaining > 1) { /* stops at 1 */ + if (previous0) { + unsigned start = charnum; + while (!normalizedCounter[charnum]) + charnum++; + while (charnum >= start + 24) { + start += 24; + bitStream += 0xFFFFU << bitCount; + if ((!writeIsSafe) && (out > oend - 2)) + return ERROR(dstSize_tooSmall); /* Buffer overflow */ + out[0] = (BYTE)bitStream; + out[1] = (BYTE)(bitStream >> 8); + out += 2; + bitStream >>= 16; + } + while (charnum >= start + 3) { + start += 3; + bitStream += 3 << bitCount; + bitCount += 2; + } + bitStream += (charnum - start) << bitCount; + bitCount += 2; + if (bitCount > 16) { + if ((!writeIsSafe) && (out > oend - 2)) + return ERROR(dstSize_tooSmall); /* Buffer overflow */ + out[0] = (BYTE)bitStream; + out[1] = (BYTE)(bitStream >> 8); + out += 2; + bitStream >>= 16; + bitCount -= 16; + } + } + { + int count = normalizedCounter[charnum++]; + int const max = (2 * threshold - 1) - remaining; + remaining -= count < 0 ? -count : count; + count++; /* +1 for extra accuracy */ + if (count >= threshold) + count += max; /* [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[ */ + bitStream += count << bitCount; + bitCount += nbBits; + bitCount -= (count < max); + previous0 = (count == 1); + if (remaining < 1) + return ERROR(GENERIC); + while (remaining < threshold) + nbBits--, threshold >>= 1; + } + if (bitCount > 16) { + if ((!writeIsSafe) && (out > oend - 2)) + return ERROR(dstSize_tooSmall); /* Buffer overflow */ + out[0] = (BYTE)bitStream; + out[1] = (BYTE)(bitStream >> 8); + out += 2; + bitStream >>= 16; + bitCount -= 16; + } + } + + /* flush remaining bitStream */ + if ((!writeIsSafe) && (out > oend - 2)) + return ERROR(dstSize_tooSmall); /* Buffer overflow */ + out[0] = (BYTE)bitStream; + out[1] = (BYTE)(bitStream >> 8); + out += (bitCount + 7) / 8; + + if (charnum > maxSymbolValue + 1) + return ERROR(GENERIC); + + return (out - ostart); +} + +size_t FSE_writeNCount(void *buffer, size_t bufferSize, const short *normalizedCounter, unsigned maxSymbolValue, unsigned tableLog) +{ + if (tableLog > FSE_MAX_TABLELOG) + return ERROR(tableLog_tooLarge); /* Unsupported */ + if (tableLog < FSE_MIN_TABLELOG) + return ERROR(GENERIC); /* Unsupported */ + + if (bufferSize < FSE_NCountWriteBound(maxSymbolValue, tableLog)) + return FSE_writeNCount_generic(buffer, bufferSize, normalizedCounter, maxSymbolValue, tableLog, 0); + + return FSE_writeNCount_generic(buffer, bufferSize, normalizedCounter, maxSymbolValue, tableLog, 1); +} + +/*-************************************************************** +* Counting histogram +****************************************************************/ +/*! FSE_count_simple + This function counts byte values within `src`, and store the histogram into table `count`. + It doesn't use any additional memory. + But this function is unsafe : it doesn't check that all values within `src` can fit into `count`. + For this reason, prefer using a table `count` with 256 elements. + @return : count of most numerous element +*/ +size_t FSE_count_simple(unsigned *count, unsigned *maxSymbolValuePtr, const void *src, size_t srcSize) +{ + const BYTE *ip = (const BYTE *)src; + const BYTE *const end = ip + srcSize; + unsigned maxSymbolValue = *maxSymbolValuePtr; + unsigned max = 0; + + memset(count, 0, (maxSymbolValue + 1) * sizeof(*count)); + if (srcSize == 0) { + *maxSymbolValuePtr = 0; + return 0; + } + + while (ip < end) + count[*ip++]++; + + while (!count[maxSymbolValue]) + maxSymbolValue--; + *maxSymbolValuePtr = maxSymbolValue; + + { + U32 s; + for (s = 0; s <= maxSymbolValue; s++) + if (count[s] > max) + max = count[s]; + } + + return (size_t)max; +} + +/* FSE_count_parallel_wksp() : + * Same as FSE_count_parallel(), but using an externally provided scratch buffer. + * `workSpace` size must be a minimum of `1024 * sizeof(unsigned)`` */ +static size_t FSE_count_parallel_wksp(unsigned *count, unsigned *maxSymbolValuePtr, const void *source, size_t sourceSize, unsigned checkMax, + unsigned *const workSpace) +{ + const BYTE *ip = (const BYTE *)source; + const BYTE *const iend = ip + sourceSize; + unsigned maxSymbolValue = *maxSymbolValuePtr; + unsigned max = 0; + U32 *const Counting1 = workSpace; + U32 *const Counting2 = Counting1 + 256; + U32 *const Counting3 = Counting2 + 256; + U32 *const Counting4 = Counting3 + 256; + + memset(Counting1, 0, 4 * 256 * sizeof(unsigned)); + + /* safety checks */ + if (!sourceSize) { + memset(count, 0, maxSymbolValue + 1); + *maxSymbolValuePtr = 0; + return 0; + } + if (!maxSymbolValue) + maxSymbolValue = 255; /* 0 == default */ + + /* by stripes of 16 bytes */ + { + U32 cached = ZSTD_read32(ip); + ip += 4; + while (ip < iend - 15) { + U32 c = cached; + cached = ZSTD_read32(ip); + ip += 4; + Counting1[(BYTE)c]++; + Counting2[(BYTE)(c >> 8)]++; + Counting3[(BYTE)(c >> 16)]++; + Counting4[c >> 24]++; + c = cached; + cached = ZSTD_read32(ip); + ip += 4; + Counting1[(BYTE)c]++; + Counting2[(BYTE)(c >> 8)]++; + Counting3[(BYTE)(c >> 16)]++; + Counting4[c >> 24]++; + c = cached; + cached = ZSTD_read32(ip); + ip += 4; + Counting1[(BYTE)c]++; + Counting2[(BYTE)(c >> 8)]++; + Counting3[(BYTE)(c >> 16)]++; + Counting4[c >> 24]++; + c = cached; + cached = ZSTD_read32(ip); + ip += 4; + Counting1[(BYTE)c]++; + Counting2[(BYTE)(c >> 8)]++; + Counting3[(BYTE)(c >> 16)]++; + Counting4[c >> 24]++; + } + ip -= 4; + } + + /* finish last symbols */ + while (ip < iend) + Counting1[*ip++]++; + + if (checkMax) { /* verify stats will fit into destination table */ + U32 s; + for (s = 255; s > maxSymbolValue; s--) { + Counting1[s] += Counting2[s] + Counting3[s] + Counting4[s]; + if (Counting1[s]) + return ERROR(maxSymbolValue_tooSmall); + } + } + + { + U32 s; + for (s = 0; s <= maxSymbolValue; s++) { + count[s] = Counting1[s] + Counting2[s] + Counting3[s] + Counting4[s]; + if (count[s] > max) + max = count[s]; + } + } + + while (!count[maxSymbolValue]) + maxSymbolValue--; + *maxSymbolValuePtr = maxSymbolValue; + return (size_t)max; +} + +/* FSE_countFast_wksp() : + * Same as FSE_countFast(), but using an externally provided scratch buffer. + * `workSpace` size must be table of >= `1024` unsigned */ +size_t FSE_countFast_wksp(unsigned *count, unsigned *maxSymbolValuePtr, const void *source, size_t sourceSize, unsigned *workSpace) +{ + if (sourceSize < 1500) + return FSE_count_simple(count, maxSymbolValuePtr, source, sourceSize); + return FSE_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, 0, workSpace); +} + +/* FSE_count_wksp() : + * Same as FSE_count(), but using an externally provided scratch buffer. + * `workSpace` size must be table of >= `1024` unsigned */ +size_t FSE_count_wksp(unsigned *count, unsigned *maxSymbolValuePtr, const void *source, size_t sourceSize, unsigned *workSpace) +{ + if (*maxSymbolValuePtr < 255) + return FSE_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, 1, workSpace); + *maxSymbolValuePtr = 255; + return FSE_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, workSpace); +} + +/*-************************************************************** +* FSE Compression Code +****************************************************************/ +/*! FSE_sizeof_CTable() : + FSE_CTable is a variable size structure which contains : + `U16 tableLog;` + `U16 maxSymbolValue;` + `U16 nextStateNumber[1 << tableLog];` // This size is variable + `FSE_symbolCompressionTransform symbolTT[maxSymbolValue+1];` // This size is variable +Allocation is manual (C standard does not support variable-size structures). +*/ +size_t FSE_sizeof_CTable(unsigned maxSymbolValue, unsigned tableLog) +{ + if (tableLog > FSE_MAX_TABLELOG) + return ERROR(tableLog_tooLarge); + return FSE_CTABLE_SIZE_U32(tableLog, maxSymbolValue) * sizeof(U32); +} + +/* provides the minimum logSize to safely represent a distribution */ +static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) +{ + U32 minBitsSrc = BIT_highbit32((U32)(srcSize - 1)) + 1; + U32 minBitsSymbols = BIT_highbit32(maxSymbolValue) + 2; + U32 minBits = minBitsSrc < minBitsSymbols ? minBitsSrc : minBitsSymbols; + return minBits; +} + +unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus) +{ + U32 maxBitsSrc = BIT_highbit32((U32)(srcSize - 1)) - minus; + U32 tableLog = maxTableLog; + U32 minBits = FSE_minTableLog(srcSize, maxSymbolValue); + if (tableLog == 0) + tableLog = FSE_DEFAULT_TABLELOG; + if (maxBitsSrc < tableLog) + tableLog = maxBitsSrc; /* Accuracy can be reduced */ + if (minBits > tableLog) + tableLog = minBits; /* Need a minimum to safely represent all symbol values */ + if (tableLog < FSE_MIN_TABLELOG) + tableLog = FSE_MIN_TABLELOG; + if (tableLog > FSE_MAX_TABLELOG) + tableLog = FSE_MAX_TABLELOG; + return tableLog; +} + +unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue) +{ + return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 2); +} + +/* Secondary normalization method. + To be used when primary method fails. */ + +static size_t FSE_normalizeM2(short *norm, U32 tableLog, const unsigned *count, size_t total, U32 maxSymbolValue) +{ + short const NOT_YET_ASSIGNED = -2; + U32 s; + U32 distributed = 0; + U32 ToDistribute; + + /* Init */ + U32 const lowThreshold = (U32)(total >> tableLog); + U32 lowOne = (U32)((total * 3) >> (tableLog + 1)); + + for (s = 0; s <= maxSymbolValue; s++) { + if (count[s] == 0) { + norm[s] = 0; + continue; + } + if (count[s] <= lowThreshold) { + norm[s] = -1; + distributed++; + total -= count[s]; + continue; + } + if (count[s] <= lowOne) { + norm[s] = 1; + distributed++; + total -= count[s]; + continue; + } + + norm[s] = NOT_YET_ASSIGNED; + } + ToDistribute = (1 << tableLog) - distributed; + + if ((total / ToDistribute) > lowOne) { + /* risk of rounding to zero */ + lowOne = (U32)((total * 3) / (ToDistribute * 2)); + for (s = 0; s <= maxSymbolValue; s++) { + if ((norm[s] == NOT_YET_ASSIGNED) && (count[s] <= lowOne)) { + norm[s] = 1; + distributed++; + total -= count[s]; + continue; + } + } + ToDistribute = (1 << tableLog) - distributed; + } + + if (distributed == maxSymbolValue + 1) { + /* all values are pretty poor; + probably incompressible data (should have already been detected); + find max, then give all remaining points to max */ + U32 maxV = 0, maxC = 0; + for (s = 0; s <= maxSymbolValue; s++) + if (count[s] > maxC) + maxV = s, maxC = count[s]; + norm[maxV] += (short)ToDistribute; + return 0; + } + + if (total == 0) { + /* all of the symbols were low enough for the lowOne or lowThreshold */ + for (s = 0; ToDistribute > 0; s = (s + 1) % (maxSymbolValue + 1)) + if (norm[s] > 0) + ToDistribute--, norm[s]++; + return 0; + } + + { + U64 const vStepLog = 62 - tableLog; + U64 const mid = (1ULL << (vStepLog - 1)) - 1; + U64 const rStep = div_u64((((U64)1 << vStepLog) * ToDistribute) + mid, (U32)total); /* scale on remaining */ + U64 tmpTotal = mid; + for (s = 0; s <= maxSymbolValue; s++) { + if (norm[s] == NOT_YET_ASSIGNED) { + U64 const end = tmpTotal + (count[s] * rStep); + U32 const sStart = (U32)(tmpTotal >> vStepLog); + U32 const sEnd = (U32)(end >> vStepLog); + U32 const weight = sEnd - sStart; + if (weight < 1) + return ERROR(GENERIC); + norm[s] = (short)weight; + tmpTotal = end; + } + } + } + + return 0; +} + +size_t FSE_normalizeCount(short *normalizedCounter, unsigned tableLog, const unsigned *count, size_t total, unsigned maxSymbolValue) +{ + /* Sanity checks */ + if (tableLog == 0) + tableLog = FSE_DEFAULT_TABLELOG; + if (tableLog < FSE_MIN_TABLELOG) + return ERROR(GENERIC); /* Unsupported size */ + if (tableLog > FSE_MAX_TABLELOG) + return ERROR(tableLog_tooLarge); /* Unsupported size */ + if (tableLog < FSE_minTableLog(total, maxSymbolValue)) + return ERROR(GENERIC); /* Too small tableLog, compression potentially impossible */ + + { + U32 const rtbTable[] = {0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}; + U64 const scale = 62 - tableLog; + U64 const step = div_u64((U64)1 << 62, (U32)total); /* <== here, one division ! */ + U64 const vStep = 1ULL << (scale - 20); + int stillToDistribute = 1 << tableLog; + unsigned s; + unsigned largest = 0; + short largestP = 0; + U32 lowThreshold = (U32)(total >> tableLog); + + for (s = 0; s <= maxSymbolValue; s++) { + if (count[s] == total) + return 0; /* rle special case */ + if (count[s] == 0) { + normalizedCounter[s] = 0; + continue; + } + if (count[s] <= lowThreshold) { + normalizedCounter[s] = -1; + stillToDistribute--; + } else { + short proba = (short)((count[s] * step) >> scale); + if (proba < 8) { + U64 restToBeat = vStep * rtbTable[proba]; + proba += (count[s] * step) - ((U64)proba << scale) > restToBeat; + } + if (proba > largestP) + largestP = proba, largest = s; + normalizedCounter[s] = proba; + stillToDistribute -= proba; + } + } + if (-stillToDistribute >= (normalizedCounter[largest] >> 1)) { + /* corner case, need another normalization method */ + size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue); + if (FSE_isError(errorCode)) + return errorCode; + } else + normalizedCounter[largest] += (short)stillToDistribute; + } + + return tableLog; +} + +/* fake FSE_CTable, for raw (uncompressed) input */ +size_t FSE_buildCTable_raw(FSE_CTable *ct, unsigned nbBits) +{ + const unsigned tableSize = 1 << nbBits; + const unsigned tableMask = tableSize - 1; + const unsigned maxSymbolValue = tableMask; + void *const ptr = ct; + U16 *const tableU16 = ((U16 *)ptr) + 2; + void *const FSCT = ((U32 *)ptr) + 1 /* header */ + (tableSize >> 1); /* assumption : tableLog >= 1 */ + FSE_symbolCompressionTransform *const symbolTT = (FSE_symbolCompressionTransform *)(FSCT); + unsigned s; + + /* Sanity checks */ + if (nbBits < 1) + return ERROR(GENERIC); /* min size */ + + /* header */ + tableU16[-2] = (U16)nbBits; + tableU16[-1] = (U16)maxSymbolValue; + + /* Build table */ + for (s = 0; s < tableSize; s++) + tableU16[s] = (U16)(tableSize + s); + + /* Build Symbol Transformation Table */ + { + const U32 deltaNbBits = (nbBits << 16) - (1 << nbBits); + for (s = 0; s <= maxSymbolValue; s++) { + symbolTT[s].deltaNbBits = deltaNbBits; + symbolTT[s].deltaFindState = s - 1; + } + } + + return 0; +} + +/* fake FSE_CTable, for rle input (always same symbol) */ +size_t FSE_buildCTable_rle(FSE_CTable *ct, BYTE symbolValue) +{ + void *ptr = ct; + U16 *tableU16 = ((U16 *)ptr) + 2; + void *FSCTptr = (U32 *)ptr + 2; + FSE_symbolCompressionTransform *symbolTT = (FSE_symbolCompressionTransform *)FSCTptr; + + /* header */ + tableU16[-2] = (U16)0; + tableU16[-1] = (U16)symbolValue; + + /* Build table */ + tableU16[0] = 0; + tableU16[1] = 0; /* just in case */ + + /* Build Symbol Transformation Table */ + symbolTT[symbolValue].deltaNbBits = 0; + symbolTT[symbolValue].deltaFindState = 0; + + return 0; +} + +static size_t FSE_compress_usingCTable_generic(void *dst, size_t dstSize, const void *src, size_t srcSize, const FSE_CTable *ct, const unsigned fast) +{ + const BYTE *const istart = (const BYTE *)src; + const BYTE *const iend = istart + srcSize; + const BYTE *ip = iend; + + BIT_CStream_t bitC; + FSE_CState_t CState1, CState2; + + /* init */ + if (srcSize <= 2) + return 0; + { + size_t const initError = BIT_initCStream(&bitC, dst, dstSize); + if (FSE_isError(initError)) + return 0; /* not enough space available to write a bitstream */ + } + +#define FSE_FLUSHBITS(s) (fast ? BIT_flushBitsFast(s) : BIT_flushBits(s)) + + if (srcSize & 1) { + FSE_initCState2(&CState1, ct, *--ip); + FSE_initCState2(&CState2, ct, *--ip); + FSE_encodeSymbol(&bitC, &CState1, *--ip); + FSE_FLUSHBITS(&bitC); + } else { + FSE_initCState2(&CState2, ct, *--ip); + FSE_initCState2(&CState1, ct, *--ip); + } + + /* join to mod 4 */ + srcSize -= 2; + if ((sizeof(bitC.bitContainer) * 8 > FSE_MAX_TABLELOG * 4 + 7) && (srcSize & 2)) { /* test bit 2 */ + FSE_encodeSymbol(&bitC, &CState2, *--ip); + FSE_encodeSymbol(&bitC, &CState1, *--ip); + FSE_FLUSHBITS(&bitC); + } + + /* 2 or 4 encoding per loop */ + while (ip > istart) { + + FSE_encodeSymbol(&bitC, &CState2, *--ip); + + if (sizeof(bitC.bitContainer) * 8 < FSE_MAX_TABLELOG * 2 + 7) /* this test must be static */ + FSE_FLUSHBITS(&bitC); + + FSE_encodeSymbol(&bitC, &CState1, *--ip); + + if (sizeof(bitC.bitContainer) * 8 > FSE_MAX_TABLELOG * 4 + 7) { /* this test must be static */ + FSE_encodeSymbol(&bitC, &CState2, *--ip); + FSE_encodeSymbol(&bitC, &CState1, *--ip); + } + + FSE_FLUSHBITS(&bitC); + } + + FSE_flushCState(&bitC, &CState2); + FSE_flushCState(&bitC, &CState1); + return BIT_closeCStream(&bitC); +} + +size_t FSE_compress_usingCTable(void *dst, size_t dstSize, const void *src, size_t srcSize, const FSE_CTable *ct) +{ + unsigned const fast = (dstSize >= FSE_BLOCKBOUND(srcSize)); + + if (fast) + return FSE_compress_usingCTable_generic(dst, dstSize, src, srcSize, ct, 1); + else + return FSE_compress_usingCTable_generic(dst, dstSize, src, srcSize, ct, 0); +} + +size_t FSE_compressBound(size_t size) { return FSE_COMPRESSBOUND(size); } diff --git a/lib/zstd/fse_decompress.c b/lib/zstd/fse_decompress.c new file mode 100644 index 000000000000..a84300e5a013 --- /dev/null +++ b/lib/zstd/fse_decompress.c @@ -0,0 +1,332 @@ +/* + * FSE : Finite State Entropy decoder + * Copyright (C) 2013-2015, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ + +/* ************************************************************** +* Compiler specifics +****************************************************************/ +#define FORCE_INLINE static __always_inline + +/* ************************************************************** +* Includes +****************************************************************/ +#include "bitstream.h" +#include "fse.h" +#include +#include +#include /* memcpy, memset */ + +/* ************************************************************** +* Error Management +****************************************************************/ +#define FSE_isError ERR_isError +#define FSE_STATIC_ASSERT(c) \ + { \ + enum { FSE_static_assert = 1 / (int)(!!(c)) }; \ + } /* use only *after* variable declarations */ + +/* check and forward error code */ +#define CHECK_F(f) \ + { \ + size_t const e = f; \ + if (FSE_isError(e)) \ + return e; \ + } + +/* ************************************************************** +* Templates +****************************************************************/ +/* + designed to be included + for type-specific functions (template emulation in C) + Objective is to write these functions only once, for improved maintenance +*/ + +/* safety checks */ +#ifndef FSE_FUNCTION_EXTENSION +#error "FSE_FUNCTION_EXTENSION must be defined" +#endif +#ifndef FSE_FUNCTION_TYPE +#error "FSE_FUNCTION_TYPE must be defined" +#endif + +/* Function names */ +#define FSE_CAT(X, Y) X##Y +#define FSE_FUNCTION_NAME(X, Y) FSE_CAT(X, Y) +#define FSE_TYPE_NAME(X, Y) FSE_CAT(X, Y) + +/* Function templates */ + +size_t FSE_buildDTable_wksp(FSE_DTable *dt, const short *normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void *workspace, size_t workspaceSize) +{ + void *const tdPtr = dt + 1; /* because *dt is unsigned, 32-bits aligned on 32-bits */ + FSE_DECODE_TYPE *const tableDecode = (FSE_DECODE_TYPE *)(tdPtr); + U16 *symbolNext = (U16 *)workspace; + + U32 const maxSV1 = maxSymbolValue + 1; + U32 const tableSize = 1 << tableLog; + U32 highThreshold = tableSize - 1; + + /* Sanity Checks */ + if (workspaceSize < sizeof(U16) * (FSE_MAX_SYMBOL_VALUE + 1)) + return ERROR(tableLog_tooLarge); + if (maxSymbolValue > FSE_MAX_SYMBOL_VALUE) + return ERROR(maxSymbolValue_tooLarge); + if (tableLog > FSE_MAX_TABLELOG) + return ERROR(tableLog_tooLarge); + + /* Init, lay down lowprob symbols */ + { + FSE_DTableHeader DTableH; + DTableH.tableLog = (U16)tableLog; + DTableH.fastMode = 1; + { + S16 const largeLimit = (S16)(1 << (tableLog - 1)); + U32 s; + for (s = 0; s < maxSV1; s++) { + if (normalizedCounter[s] == -1) { + tableDecode[highThreshold--].symbol = (FSE_FUNCTION_TYPE)s; + symbolNext[s] = 1; + } else { + if (normalizedCounter[s] >= largeLimit) + DTableH.fastMode = 0; + symbolNext[s] = normalizedCounter[s]; + } + } + } + memcpy(dt, &DTableH, sizeof(DTableH)); + } + + /* Spread symbols */ + { + U32 const tableMask = tableSize - 1; + U32 const step = FSE_TABLESTEP(tableSize); + U32 s, position = 0; + for (s = 0; s < maxSV1; s++) { + int i; + for (i = 0; i < normalizedCounter[s]; i++) { + tableDecode[position].symbol = (FSE_FUNCTION_TYPE)s; + position = (position + step) & tableMask; + while (position > highThreshold) + position = (position + step) & tableMask; /* lowprob area */ + } + } + if (position != 0) + return ERROR(GENERIC); /* position must reach all cells once, otherwise normalizedCounter is incorrect */ + } + + /* Build Decoding table */ + { + U32 u; + for (u = 0; u < tableSize; u++) { + FSE_FUNCTION_TYPE const symbol = (FSE_FUNCTION_TYPE)(tableDecode[u].symbol); + U16 nextState = symbolNext[symbol]++; + tableDecode[u].nbBits = (BYTE)(tableLog - BIT_highbit32((U32)nextState)); + tableDecode[u].newState = (U16)((nextState << tableDecode[u].nbBits) - tableSize); + } + } + + return 0; +} + +/*-******************************************************* +* Decompression (Byte symbols) +*********************************************************/ +size_t FSE_buildDTable_rle(FSE_DTable *dt, BYTE symbolValue) +{ + void *ptr = dt; + FSE_DTableHeader *const DTableH = (FSE_DTableHeader *)ptr; + void *dPtr = dt + 1; + FSE_decode_t *const cell = (FSE_decode_t *)dPtr; + + DTableH->tableLog = 0; + DTableH->fastMode = 0; + + cell->newState = 0; + cell->symbol = symbolValue; + cell->nbBits = 0; + + return 0; +} + +size_t FSE_buildDTable_raw(FSE_DTable *dt, unsigned nbBits) +{ + void *ptr = dt; + FSE_DTableHeader *const DTableH = (FSE_DTableHeader *)ptr; + void *dPtr = dt + 1; + FSE_decode_t *const dinfo = (FSE_decode_t *)dPtr; + const unsigned tableSize = 1 << nbBits; + const unsigned tableMask = tableSize - 1; + const unsigned maxSV1 = tableMask + 1; + unsigned s; + + /* Sanity checks */ + if (nbBits < 1) + return ERROR(GENERIC); /* min size */ + + /* Build Decoding Table */ + DTableH->tableLog = (U16)nbBits; + DTableH->fastMode = 1; + for (s = 0; s < maxSV1; s++) { + dinfo[s].newState = 0; + dinfo[s].symbol = (BYTE)s; + dinfo[s].nbBits = (BYTE)nbBits; + } + + return 0; +} + +FORCE_INLINE size_t FSE_decompress_usingDTable_generic(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const FSE_DTable *dt, + const unsigned fast) +{ + BYTE *const ostart = (BYTE *)dst; + BYTE *op = ostart; + BYTE *const omax = op + maxDstSize; + BYTE *const olimit = omax - 3; + + BIT_DStream_t bitD; + FSE_DState_t state1; + FSE_DState_t state2; + + /* Init */ + CHECK_F(BIT_initDStream(&bitD, cSrc, cSrcSize)); + + FSE_initDState(&state1, &bitD, dt); + FSE_initDState(&state2, &bitD, dt); + +#define FSE_GETSYMBOL(statePtr) fast ? FSE_decodeSymbolFast(statePtr, &bitD) : FSE_decodeSymbol(statePtr, &bitD) + + /* 4 symbols per loop */ + for (; (BIT_reloadDStream(&bitD) == BIT_DStream_unfinished) & (op < olimit); op += 4) { + op[0] = FSE_GETSYMBOL(&state1); + + if (FSE_MAX_TABLELOG * 2 + 7 > sizeof(bitD.bitContainer) * 8) /* This test must be static */ + BIT_reloadDStream(&bitD); + + op[1] = FSE_GETSYMBOL(&state2); + + if (FSE_MAX_TABLELOG * 4 + 7 > sizeof(bitD.bitContainer) * 8) /* This test must be static */ + { + if (BIT_reloadDStream(&bitD) > BIT_DStream_unfinished) { + op += 2; + break; + } + } + + op[2] = FSE_GETSYMBOL(&state1); + + if (FSE_MAX_TABLELOG * 2 + 7 > sizeof(bitD.bitContainer) * 8) /* This test must be static */ + BIT_reloadDStream(&bitD); + + op[3] = FSE_GETSYMBOL(&state2); + } + + /* tail */ + /* note : BIT_reloadDStream(&bitD) >= FSE_DStream_partiallyFilled; Ends at exactly BIT_DStream_completed */ + while (1) { + if (op > (omax - 2)) + return ERROR(dstSize_tooSmall); + *op++ = FSE_GETSYMBOL(&state1); + if (BIT_reloadDStream(&bitD) == BIT_DStream_overflow) { + *op++ = FSE_GETSYMBOL(&state2); + break; + } + + if (op > (omax - 2)) + return ERROR(dstSize_tooSmall); + *op++ = FSE_GETSYMBOL(&state2); + if (BIT_reloadDStream(&bitD) == BIT_DStream_overflow) { + *op++ = FSE_GETSYMBOL(&state1); + break; + } + } + + return op - ostart; +} + +size_t FSE_decompress_usingDTable(void *dst, size_t originalSize, const void *cSrc, size_t cSrcSize, const FSE_DTable *dt) +{ + const void *ptr = dt; + const FSE_DTableHeader *DTableH = (const FSE_DTableHeader *)ptr; + const U32 fastMode = DTableH->fastMode; + + /* select fast mode (static) */ + if (fastMode) + return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 1); + return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 0); +} + +size_t FSE_decompress_wksp(void *dst, size_t dstCapacity, const void *cSrc, size_t cSrcSize, unsigned maxLog, void *workspace, size_t workspaceSize) +{ + const BYTE *const istart = (const BYTE *)cSrc; + const BYTE *ip = istart; + unsigned tableLog; + unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE; + size_t NCountLength; + + FSE_DTable *dt; + short *counting; + size_t spaceUsed32 = 0; + + FSE_STATIC_ASSERT(sizeof(FSE_DTable) == sizeof(U32)); + + dt = (FSE_DTable *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += FSE_DTABLE_SIZE_U32(maxLog); + counting = (short *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(sizeof(short) * (FSE_MAX_SYMBOL_VALUE + 1), sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + /* normal FSE decoding mode */ + NCountLength = FSE_readNCount(counting, &maxSymbolValue, &tableLog, istart, cSrcSize); + if (FSE_isError(NCountLength)) + return NCountLength; + // if (NCountLength >= cSrcSize) return ERROR(srcSize_wrong); /* too small input size; supposed to be already checked in NCountLength, only remaining + // case : NCountLength==cSrcSize */ + if (tableLog > maxLog) + return ERROR(tableLog_tooLarge); + ip += NCountLength; + cSrcSize -= NCountLength; + + CHECK_F(FSE_buildDTable_wksp(dt, counting, maxSymbolValue, tableLog, workspace, workspaceSize)); + + return FSE_decompress_usingDTable(dst, dstCapacity, ip, cSrcSize, dt); /* always return, even if it is an error code */ +} diff --git a/lib/zstd/huf.h b/lib/zstd/huf.h new file mode 100644 index 000000000000..2143da28d952 --- /dev/null +++ b/lib/zstd/huf.h @@ -0,0 +1,212 @@ +/* + * Huffman coder, part of New Generation Entropy library + * header file + * Copyright (C) 2013-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ +#ifndef HUF_H_298734234 +#define HUF_H_298734234 + +/* *** Dependencies *** */ +#include /* size_t */ + +/* *** Tool functions *** */ +#define HUF_BLOCKSIZE_MAX (128 * 1024) /**< maximum input size for a single block compressed with HUF_compress */ +size_t HUF_compressBound(size_t size); /**< maximum compressed size (worst case) */ + +/* Error Management */ +unsigned HUF_isError(size_t code); /**< tells if a return value is an error code */ + +/* *** Advanced function *** */ + +/** HUF_compress4X_wksp() : +* Same as HUF_compress2(), but uses externally allocated `workSpace`, which must be a table of >= 1024 unsigned */ +size_t HUF_compress4X_wksp(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void *workSpace, + size_t wkspSize); /**< `workSpace` must be a table of at least HUF_COMPRESS_WORKSPACE_SIZE_U32 unsigned */ + +/* *** Dependencies *** */ +#include "mem.h" /* U32 */ + +/* *** Constants *** */ +#define HUF_TABLELOG_MAX 12 /* max configured tableLog (for static allocation); can be modified up to HUF_ABSOLUTEMAX_TABLELOG */ +#define HUF_TABLELOG_DEFAULT 11 /* tableLog by default, when not specified */ +#define HUF_SYMBOLVALUE_MAX 255 + +#define HUF_TABLELOG_ABSOLUTEMAX 15 /* absolute limit of HUF_MAX_TABLELOG. Beyond that value, code does not work */ +#if (HUF_TABLELOG_MAX > HUF_TABLELOG_ABSOLUTEMAX) +#error "HUF_TABLELOG_MAX is too large !" +#endif + +/* **************************************** +* Static allocation +******************************************/ +/* HUF buffer bounds */ +#define HUF_CTABLEBOUND 129 +#define HUF_BLOCKBOUND(size) (size + (size >> 8) + 8) /* only true if incompressible pre-filtered with fast heuristic */ +#define HUF_COMPRESSBOUND(size) (HUF_CTABLEBOUND + HUF_BLOCKBOUND(size)) /* Macro version, useful for static allocation */ + +/* static allocation of HUF's Compression Table */ +#define HUF_CREATE_STATIC_CTABLE(name, maxSymbolValue) \ + U32 name##hb[maxSymbolValue + 1]; \ + void *name##hv = &(name##hb); \ + HUF_CElt *name = (HUF_CElt *)(name##hv) /* no final ; */ + +/* static allocation of HUF's DTable */ +typedef U32 HUF_DTable; +#define HUF_DTABLE_SIZE(maxTableLog) (1 + (1 << (maxTableLog))) +#define HUF_CREATE_STATIC_DTABLEX2(DTable, maxTableLog) HUF_DTable DTable[HUF_DTABLE_SIZE((maxTableLog)-1)] = {((U32)((maxTableLog)-1) * 0x01000001)} +#define HUF_CREATE_STATIC_DTABLEX4(DTable, maxTableLog) HUF_DTable DTable[HUF_DTABLE_SIZE(maxTableLog)] = {((U32)(maxTableLog)*0x01000001)} + +/* The workspace must have alignment at least 4 and be at least this large */ +#define HUF_COMPRESS_WORKSPACE_SIZE (6 << 10) +#define HUF_COMPRESS_WORKSPACE_SIZE_U32 (HUF_COMPRESS_WORKSPACE_SIZE / sizeof(U32)) + +/* The workspace must have alignment at least 4 and be at least this large */ +#define HUF_DECOMPRESS_WORKSPACE_SIZE (3 << 10) +#define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) + +/* **************************************** +* Advanced decompression functions +******************************************/ +size_t HUF_decompress4X_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize); /**< decodes RLE and uncompressed */ +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, + size_t workspaceSize); /**< considers RLE and uncompressed as errors */ +size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, + size_t workspaceSize); /**< single-symbol decoder */ +size_t HUF_decompress4X4_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, + size_t workspaceSize); /**< double-symbols decoder */ + +/* **************************************** +* HUF detailed API +******************************************/ +/*! +HUF_compress() does the following: +1. count symbol occurrence from source[] into table count[] using FSE_count() +2. (optional) refine tableLog using HUF_optimalTableLog() +3. build Huffman table from count using HUF_buildCTable() +4. save Huffman table to memory buffer using HUF_writeCTable_wksp() +5. encode the data stream using HUF_compress4X_usingCTable() + +The following API allows targeting specific sub-functions for advanced tasks. +For example, it's possible to compress several blocks using the same 'CTable', +or to save and regenerate 'CTable' using external methods. +*/ +/* FSE_count() : find it within "fse.h" */ +unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue); +typedef struct HUF_CElt_s HUF_CElt; /* incomplete type */ +size_t HUF_writeCTable_wksp(void *dst, size_t maxDstSize, const HUF_CElt *CTable, unsigned maxSymbolValue, unsigned huffLog, void *workspace, size_t workspaceSize); +size_t HUF_compress4X_usingCTable(void *dst, size_t dstSize, const void *src, size_t srcSize, const HUF_CElt *CTable); + +typedef enum { + HUF_repeat_none, /**< Cannot use the previous table */ + HUF_repeat_check, /**< Can use the previous table but it must be checked. Note : The previous table must have been constructed by HUF_compress{1, + 4}X_repeat */ + HUF_repeat_valid /**< Can use the previous table and it is asumed to be valid */ +} HUF_repeat; +/** HUF_compress4X_repeat() : +* Same as HUF_compress4X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. +* If it uses hufTable it does not modify hufTable or repeat. +* If it doesn't, it sets *repeat = HUF_repeat_none, and it sets hufTable to the table used. +* If preferRepeat then the old table will always be used if valid. */ +size_t HUF_compress4X_repeat(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void *workSpace, + size_t wkspSize, HUF_CElt *hufTable, HUF_repeat *repeat, + int preferRepeat); /**< `workSpace` must be a table of at least HUF_COMPRESS_WORKSPACE_SIZE_U32 unsigned */ + +/** HUF_buildCTable_wksp() : + * Same as HUF_buildCTable(), but using externally allocated scratch buffer. + * `workSpace` must be aligned on 4-bytes boundaries, and be at least as large as a table of 1024 unsigned. + */ +size_t HUF_buildCTable_wksp(HUF_CElt *tree, const U32 *count, U32 maxSymbolValue, U32 maxNbBits, void *workSpace, size_t wkspSize); + +/*! HUF_readStats() : + Read compact Huffman tree, saved by HUF_writeCTable(). + `huffWeight` is destination buffer. + @return : size read from `src` , or an error Code . + Note : Needed by HUF_readCTable() and HUF_readDTableXn() . */ +size_t HUF_readStats_wksp(BYTE *huffWeight, size_t hwSize, U32 *rankStats, U32 *nbSymbolsPtr, U32 *tableLogPtr, const void *src, size_t srcSize, + void *workspace, size_t workspaceSize); + +/** HUF_readCTable() : +* Loading a CTable saved with HUF_writeCTable() */ +size_t HUF_readCTable_wksp(HUF_CElt *CTable, unsigned maxSymbolValue, const void *src, size_t srcSize, void *workspace, size_t workspaceSize); + +/* +HUF_decompress() does the following: +1. select the decompression algorithm (X2, X4) based on pre-computed heuristics +2. build Huffman table from save, using HUF_readDTableXn() +3. decode 1 or 4 segments in parallel using HUF_decompressSXn_usingDTable +*/ + +/** HUF_selectDecoder() : +* Tells which decoder is likely to decode faster, +* based on a set of pre-determined metrics. +* @return : 0==HUF_decompress4X2, 1==HUF_decompress4X4 . +* Assumption : 0 < cSrcSize < dstSize <= 128 KB */ +U32 HUF_selectDecoder(size_t dstSize, size_t cSrcSize); + +size_t HUF_readDTableX2_wksp(HUF_DTable *DTable, const void *src, size_t srcSize, void *workspace, size_t workspaceSize); +size_t HUF_readDTableX4_wksp(HUF_DTable *DTable, const void *src, size_t srcSize, void *workspace, size_t workspaceSize); + +size_t HUF_decompress4X_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable); +size_t HUF_decompress4X2_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable); +size_t HUF_decompress4X4_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable); + +/* single stream variants */ + +size_t HUF_compress1X_wksp(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void *workSpace, + size_t wkspSize); /**< `workSpace` must be a table of at least HUF_COMPRESS_WORKSPACE_SIZE_U32 unsigned */ +size_t HUF_compress1X_usingCTable(void *dst, size_t dstSize, const void *src, size_t srcSize, const HUF_CElt *CTable); +/** HUF_compress1X_repeat() : +* Same as HUF_compress1X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. +* If it uses hufTable it does not modify hufTable or repeat. +* If it doesn't, it sets *repeat = HUF_repeat_none, and it sets hufTable to the table used. +* If preferRepeat then the old table will always be used if valid. */ +size_t HUF_compress1X_repeat(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void *workSpace, + size_t wkspSize, HUF_CElt *hufTable, HUF_repeat *repeat, + int preferRepeat); /**< `workSpace` must be a table of at least HUF_COMPRESS_WORKSPACE_SIZE_U32 unsigned */ + +size_t HUF_decompress1X_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize); +size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, + size_t workspaceSize); /**< single-symbol decoder */ +size_t HUF_decompress1X4_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, + size_t workspaceSize); /**< double-symbols decoder */ + +size_t HUF_decompress1X_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, + const HUF_DTable *DTable); /**< automatic selection of sing or double symbol decoder, based on DTable */ +size_t HUF_decompress1X2_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable); +size_t HUF_decompress1X4_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable); + +#endif /* HUF_H_298734234 */ diff --git a/lib/zstd/huf_compress.c b/lib/zstd/huf_compress.c new file mode 100644 index 000000000000..40055a7016e6 --- /dev/null +++ b/lib/zstd/huf_compress.c @@ -0,0 +1,770 @@ +/* + * Huffman encoder, part of New Generation Entropy library + * Copyright (C) 2013-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ + +/* ************************************************************** +* Includes +****************************************************************/ +#include "bitstream.h" +#include "fse.h" /* header compression */ +#include "huf.h" +#include +#include /* memcpy, memset */ + +/* ************************************************************** +* Error Management +****************************************************************/ +#define HUF_STATIC_ASSERT(c) \ + { \ + enum { HUF_static_assert = 1 / (int)(!!(c)) }; \ + } /* use only *after* variable declarations */ +#define CHECK_V_F(e, f) \ + size_t const e = f; \ + if (ERR_isError(e)) \ + return f +#define CHECK_F(f) \ + { \ + CHECK_V_F(_var_err__, f); \ + } + +/* ************************************************************** +* Utils +****************************************************************/ +unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue) +{ + return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); +} + +/* ******************************************************* +* HUF : Huffman block compression +*********************************************************/ +/* HUF_compressWeights() : + * Same as FSE_compress(), but dedicated to huff0's weights compression. + * The use case needs much less stack memory. + * Note : all elements within weightTable are supposed to be <= HUF_TABLELOG_MAX. + */ +#define MAX_FSE_TABLELOG_FOR_HUFF_HEADER 6 +size_t HUF_compressWeights_wksp(void *dst, size_t dstSize, const void *weightTable, size_t wtSize, void *workspace, size_t workspaceSize) +{ + BYTE *const ostart = (BYTE *)dst; + BYTE *op = ostart; + BYTE *const oend = ostart + dstSize; + + U32 maxSymbolValue = HUF_TABLELOG_MAX; + U32 tableLog = MAX_FSE_TABLELOG_FOR_HUFF_HEADER; + + FSE_CTable *CTable; + U32 *count; + S16 *norm; + size_t spaceUsed32 = 0; + + HUF_STATIC_ASSERT(sizeof(FSE_CTable) == sizeof(U32)); + + CTable = (FSE_CTable *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += FSE_CTABLE_SIZE_U32(MAX_FSE_TABLELOG_FOR_HUFF_HEADER, HUF_TABLELOG_MAX); + count = (U32 *)workspace + spaceUsed32; + spaceUsed32 += HUF_TABLELOG_MAX + 1; + norm = (S16 *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(sizeof(S16) * (HUF_TABLELOG_MAX + 1), sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + /* init conditions */ + if (wtSize <= 1) + return 0; /* Not compressible */ + + /* Scan input and build symbol stats */ + { + CHECK_V_F(maxCount, FSE_count_simple(count, &maxSymbolValue, weightTable, wtSize)); + if (maxCount == wtSize) + return 1; /* only a single symbol in src : rle */ + if (maxCount == 1) + return 0; /* each symbol present maximum once => not compressible */ + } + + tableLog = FSE_optimalTableLog(tableLog, wtSize, maxSymbolValue); + CHECK_F(FSE_normalizeCount(norm, tableLog, count, wtSize, maxSymbolValue)); + + /* Write table description header */ + { + CHECK_V_F(hSize, FSE_writeNCount(op, oend - op, norm, maxSymbolValue, tableLog)); + op += hSize; + } + + /* Compress */ + CHECK_F(FSE_buildCTable_wksp(CTable, norm, maxSymbolValue, tableLog, workspace, workspaceSize)); + { + CHECK_V_F(cSize, FSE_compress_usingCTable(op, oend - op, weightTable, wtSize, CTable)); + if (cSize == 0) + return 0; /* not enough space for compressed data */ + op += cSize; + } + + return op - ostart; +} + +struct HUF_CElt_s { + U16 val; + BYTE nbBits; +}; /* typedef'd to HUF_CElt within "huf.h" */ + +/*! HUF_writeCTable_wksp() : + `CTable` : Huffman tree to save, using huf representation. + @return : size of saved CTable */ +size_t HUF_writeCTable_wksp(void *dst, size_t maxDstSize, const HUF_CElt *CTable, U32 maxSymbolValue, U32 huffLog, void *workspace, size_t workspaceSize) +{ + BYTE *op = (BYTE *)dst; + U32 n; + + BYTE *bitsToWeight; + BYTE *huffWeight; + size_t spaceUsed32 = 0; + + bitsToWeight = (BYTE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(HUF_TABLELOG_MAX + 1, sizeof(U32)) >> 2; + huffWeight = (BYTE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(HUF_SYMBOLVALUE_MAX, sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + /* check conditions */ + if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) + return ERROR(maxSymbolValue_tooLarge); + + /* convert to weight */ + bitsToWeight[0] = 0; + for (n = 1; n < huffLog + 1; n++) + bitsToWeight[n] = (BYTE)(huffLog + 1 - n); + for (n = 0; n < maxSymbolValue; n++) + huffWeight[n] = bitsToWeight[CTable[n].nbBits]; + + /* attempt weights compression by FSE */ + { + CHECK_V_F(hSize, HUF_compressWeights_wksp(op + 1, maxDstSize - 1, huffWeight, maxSymbolValue, workspace, workspaceSize)); + if ((hSize > 1) & (hSize < maxSymbolValue / 2)) { /* FSE compressed */ + op[0] = (BYTE)hSize; + return hSize + 1; + } + } + + /* write raw values as 4-bits (max : 15) */ + if (maxSymbolValue > (256 - 128)) + return ERROR(GENERIC); /* should not happen : likely means source cannot be compressed */ + if (((maxSymbolValue + 1) / 2) + 1 > maxDstSize) + return ERROR(dstSize_tooSmall); /* not enough space within dst buffer */ + op[0] = (BYTE)(128 /*special case*/ + (maxSymbolValue - 1)); + huffWeight[maxSymbolValue] = 0; /* to be sure it doesn't cause msan issue in final combination */ + for (n = 0; n < maxSymbolValue; n += 2) + op[(n / 2) + 1] = (BYTE)((huffWeight[n] << 4) + huffWeight[n + 1]); + return ((maxSymbolValue + 1) / 2) + 1; +} + +size_t HUF_readCTable_wksp(HUF_CElt *CTable, U32 maxSymbolValue, const void *src, size_t srcSize, void *workspace, size_t workspaceSize) +{ + U32 *rankVal; + BYTE *huffWeight; + U32 tableLog = 0; + U32 nbSymbols = 0; + size_t readSize; + size_t spaceUsed32 = 0; + + rankVal = (U32 *)workspace + spaceUsed32; + spaceUsed32 += HUF_TABLELOG_ABSOLUTEMAX + 1; + huffWeight = (BYTE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(HUF_SYMBOLVALUE_MAX + 1, sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + /* get symbol weights */ + readSize = HUF_readStats_wksp(huffWeight, HUF_SYMBOLVALUE_MAX + 1, rankVal, &nbSymbols, &tableLog, src, srcSize, workspace, workspaceSize); + if (ERR_isError(readSize)) + return readSize; + + /* check result */ + if (tableLog > HUF_TABLELOG_MAX) + return ERROR(tableLog_tooLarge); + if (nbSymbols > maxSymbolValue + 1) + return ERROR(maxSymbolValue_tooSmall); + + /* Prepare base value per rank */ + { + U32 n, nextRankStart = 0; + for (n = 1; n <= tableLog; n++) { + U32 curr = nextRankStart; + nextRankStart += (rankVal[n] << (n - 1)); + rankVal[n] = curr; + } + } + + /* fill nbBits */ + { + U32 n; + for (n = 0; n < nbSymbols; n++) { + const U32 w = huffWeight[n]; + CTable[n].nbBits = (BYTE)(tableLog + 1 - w); + } + } + + /* fill val */ + { + U16 nbPerRank[HUF_TABLELOG_MAX + 2] = {0}; /* support w=0=>n=tableLog+1 */ + U16 valPerRank[HUF_TABLELOG_MAX + 2] = {0}; + { + U32 n; + for (n = 0; n < nbSymbols; n++) + nbPerRank[CTable[n].nbBits]++; + } + /* determine stating value per rank */ + valPerRank[tableLog + 1] = 0; /* for w==0 */ + { + U16 min = 0; + U32 n; + for (n = tableLog; n > 0; n--) { /* start at n=tablelog <-> w=1 */ + valPerRank[n] = min; /* get starting value within each rank */ + min += nbPerRank[n]; + min >>= 1; + } + } + /* assign value within rank, symbol order */ + { + U32 n; + for (n = 0; n <= maxSymbolValue; n++) + CTable[n].val = valPerRank[CTable[n].nbBits]++; + } + } + + return readSize; +} + +typedef struct nodeElt_s { + U32 count; + U16 parent; + BYTE byte; + BYTE nbBits; +} nodeElt; + +static U32 HUF_setMaxHeight(nodeElt *huffNode, U32 lastNonNull, U32 maxNbBits) +{ + const U32 largestBits = huffNode[lastNonNull].nbBits; + if (largestBits <= maxNbBits) + return largestBits; /* early exit : no elt > maxNbBits */ + + /* there are several too large elements (at least >= 2) */ + { + int totalCost = 0; + const U32 baseCost = 1 << (largestBits - maxNbBits); + U32 n = lastNonNull; + + while (huffNode[n].nbBits > maxNbBits) { + totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)); + huffNode[n].nbBits = (BYTE)maxNbBits; + n--; + } /* n stops at huffNode[n].nbBits <= maxNbBits */ + while (huffNode[n].nbBits == maxNbBits) + n--; /* n end at index of smallest symbol using < maxNbBits */ + + /* renorm totalCost */ + totalCost >>= (largestBits - maxNbBits); /* note : totalCost is necessarily a multiple of baseCost */ + + /* repay normalized cost */ + { + U32 const noSymbol = 0xF0F0F0F0; + U32 rankLast[HUF_TABLELOG_MAX + 2]; + int pos; + + /* Get pos of last (smallest) symbol per rank */ + memset(rankLast, 0xF0, sizeof(rankLast)); + { + U32 currNbBits = maxNbBits; + for (pos = n; pos >= 0; pos--) { + if (huffNode[pos].nbBits >= currNbBits) + continue; + currNbBits = huffNode[pos].nbBits; /* < maxNbBits */ + rankLast[maxNbBits - currNbBits] = pos; + } + } + + while (totalCost > 0) { + U32 nBitsToDecrease = BIT_highbit32(totalCost) + 1; + for (; nBitsToDecrease > 1; nBitsToDecrease--) { + U32 highPos = rankLast[nBitsToDecrease]; + U32 lowPos = rankLast[nBitsToDecrease - 1]; + if (highPos == noSymbol) + continue; + if (lowPos == noSymbol) + break; + { + U32 const highTotal = huffNode[highPos].count; + U32 const lowTotal = 2 * huffNode[lowPos].count; + if (highTotal <= lowTotal) + break; + } + } + /* only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !) */ + /* HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary */ + while ((nBitsToDecrease <= HUF_TABLELOG_MAX) && (rankLast[nBitsToDecrease] == noSymbol)) + nBitsToDecrease++; + totalCost -= 1 << (nBitsToDecrease - 1); + if (rankLast[nBitsToDecrease - 1] == noSymbol) + rankLast[nBitsToDecrease - 1] = rankLast[nBitsToDecrease]; /* this rank is no longer empty */ + huffNode[rankLast[nBitsToDecrease]].nbBits++; + if (rankLast[nBitsToDecrease] == 0) /* special case, reached largest symbol */ + rankLast[nBitsToDecrease] = noSymbol; + else { + rankLast[nBitsToDecrease]--; + if (huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits - nBitsToDecrease) + rankLast[nBitsToDecrease] = noSymbol; /* this rank is now empty */ + } + } /* while (totalCost > 0) */ + + while (totalCost < 0) { /* Sometimes, cost correction overshoot */ + if (rankLast[1] == noSymbol) { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 + (using maxNbBits) */ + while (huffNode[n].nbBits == maxNbBits) + n--; + huffNode[n + 1].nbBits--; + rankLast[1] = n + 1; + totalCost++; + continue; + } + huffNode[rankLast[1] + 1].nbBits--; + rankLast[1]++; + totalCost++; + } + } + } /* there are several too large elements (at least >= 2) */ + + return maxNbBits; +} + +typedef struct { + U32 base; + U32 curr; +} rankPos; + +static void HUF_sort(nodeElt *huffNode, const U32 *count, U32 maxSymbolValue) +{ + rankPos rank[32]; + U32 n; + + memset(rank, 0, sizeof(rank)); + for (n = 0; n <= maxSymbolValue; n++) { + U32 r = BIT_highbit32(count[n] + 1); + rank[r].base++; + } + for (n = 30; n > 0; n--) + rank[n - 1].base += rank[n].base; + for (n = 0; n < 32; n++) + rank[n].curr = rank[n].base; + for (n = 0; n <= maxSymbolValue; n++) { + U32 const c = count[n]; + U32 const r = BIT_highbit32(c + 1) + 1; + U32 pos = rank[r].curr++; + while ((pos > rank[r].base) && (c > huffNode[pos - 1].count)) + huffNode[pos] = huffNode[pos - 1], pos--; + huffNode[pos].count = c; + huffNode[pos].byte = (BYTE)n; + } +} + +/** HUF_buildCTable_wksp() : + * Same as HUF_buildCTable(), but using externally allocated scratch buffer. + * `workSpace` must be aligned on 4-bytes boundaries, and be at least as large as a table of 1024 unsigned. + */ +#define STARTNODE (HUF_SYMBOLVALUE_MAX + 1) +typedef nodeElt huffNodeTable[2 * HUF_SYMBOLVALUE_MAX + 1 + 1]; +size_t HUF_buildCTable_wksp(HUF_CElt *tree, const U32 *count, U32 maxSymbolValue, U32 maxNbBits, void *workSpace, size_t wkspSize) +{ + nodeElt *const huffNode0 = (nodeElt *)workSpace; + nodeElt *const huffNode = huffNode0 + 1; + U32 n, nonNullRank; + int lowS, lowN; + U16 nodeNb = STARTNODE; + U32 nodeRoot; + + /* safety checks */ + if (wkspSize < sizeof(huffNodeTable)) + return ERROR(GENERIC); /* workSpace is not large enough */ + if (maxNbBits == 0) + maxNbBits = HUF_TABLELOG_DEFAULT; + if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) + return ERROR(GENERIC); + memset(huffNode0, 0, sizeof(huffNodeTable)); + + /* sort, decreasing order */ + HUF_sort(huffNode, count, maxSymbolValue); + + /* init for parents */ + nonNullRank = maxSymbolValue; + while (huffNode[nonNullRank].count == 0) + nonNullRank--; + lowS = nonNullRank; + nodeRoot = nodeNb + lowS - 1; + lowN = nodeNb; + huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS - 1].count; + huffNode[lowS].parent = huffNode[lowS - 1].parent = nodeNb; + nodeNb++; + lowS -= 2; + for (n = nodeNb; n <= nodeRoot; n++) + huffNode[n].count = (U32)(1U << 30); + huffNode0[0].count = (U32)(1U << 31); /* fake entry, strong barrier */ + + /* create parents */ + while (nodeNb <= nodeRoot) { + U32 n1 = (huffNode[lowS].count < huffNode[lowN].count) ? lowS-- : lowN++; + U32 n2 = (huffNode[lowS].count < huffNode[lowN].count) ? lowS-- : lowN++; + huffNode[nodeNb].count = huffNode[n1].count + huffNode[n2].count; + huffNode[n1].parent = huffNode[n2].parent = nodeNb; + nodeNb++; + } + + /* distribute weights (unlimited tree height) */ + huffNode[nodeRoot].nbBits = 0; + for (n = nodeRoot - 1; n >= STARTNODE; n--) + huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1; + for (n = 0; n <= nonNullRank; n++) + huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1; + + /* enforce maxTableLog */ + maxNbBits = HUF_setMaxHeight(huffNode, nonNullRank, maxNbBits); + + /* fill result into tree (val, nbBits) */ + { + U16 nbPerRank[HUF_TABLELOG_MAX + 1] = {0}; + U16 valPerRank[HUF_TABLELOG_MAX + 1] = {0}; + if (maxNbBits > HUF_TABLELOG_MAX) + return ERROR(GENERIC); /* check fit into table */ + for (n = 0; n <= nonNullRank; n++) + nbPerRank[huffNode[n].nbBits]++; + /* determine stating value per rank */ + { + U16 min = 0; + for (n = maxNbBits; n > 0; n--) { + valPerRank[n] = min; /* get starting value within each rank */ + min += nbPerRank[n]; + min >>= 1; + } + } + for (n = 0; n <= maxSymbolValue; n++) + tree[huffNode[n].byte].nbBits = huffNode[n].nbBits; /* push nbBits per symbol, symbol order */ + for (n = 0; n <= maxSymbolValue; n++) + tree[n].val = valPerRank[tree[n].nbBits]++; /* assign value within rank, symbol order */ + } + + return maxNbBits; +} + +static size_t HUF_estimateCompressedSize(HUF_CElt *CTable, const unsigned *count, unsigned maxSymbolValue) +{ + size_t nbBits = 0; + int s; + for (s = 0; s <= (int)maxSymbolValue; ++s) { + nbBits += CTable[s].nbBits * count[s]; + } + return nbBits >> 3; +} + +static int HUF_validateCTable(const HUF_CElt *CTable, const unsigned *count, unsigned maxSymbolValue) +{ + int bad = 0; + int s; + for (s = 0; s <= (int)maxSymbolValue; ++s) { + bad |= (count[s] != 0) & (CTable[s].nbBits == 0); + } + return !bad; +} + +static void HUF_encodeSymbol(BIT_CStream_t *bitCPtr, U32 symbol, const HUF_CElt *CTable) +{ + BIT_addBitsFast(bitCPtr, CTable[symbol].val, CTable[symbol].nbBits); +} + +size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); } + +#define HUF_FLUSHBITS(s) BIT_flushBits(s) + +#define HUF_FLUSHBITS_1(stream) \ + if (sizeof((stream)->bitContainer) * 8 < HUF_TABLELOG_MAX * 2 + 7) \ + HUF_FLUSHBITS(stream) + +#define HUF_FLUSHBITS_2(stream) \ + if (sizeof((stream)->bitContainer) * 8 < HUF_TABLELOG_MAX * 4 + 7) \ + HUF_FLUSHBITS(stream) + +size_t HUF_compress1X_usingCTable(void *dst, size_t dstSize, const void *src, size_t srcSize, const HUF_CElt *CTable) +{ + const BYTE *ip = (const BYTE *)src; + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + BYTE *op = ostart; + size_t n; + BIT_CStream_t bitC; + + /* init */ + if (dstSize < 8) + return 0; /* not enough space to compress */ + { + size_t const initErr = BIT_initCStream(&bitC, op, oend - op); + if (HUF_isError(initErr)) + return 0; + } + + n = srcSize & ~3; /* join to mod 4 */ + switch (srcSize & 3) { + case 3: HUF_encodeSymbol(&bitC, ip[n + 2], CTable); HUF_FLUSHBITS_2(&bitC); + case 2: HUF_encodeSymbol(&bitC, ip[n + 1], CTable); HUF_FLUSHBITS_1(&bitC); + case 1: HUF_encodeSymbol(&bitC, ip[n + 0], CTable); HUF_FLUSHBITS(&bitC); + case 0: + default:; + } + + for (; n > 0; n -= 4) { /* note : n&3==0 at this stage */ + HUF_encodeSymbol(&bitC, ip[n - 1], CTable); + HUF_FLUSHBITS_1(&bitC); + HUF_encodeSymbol(&bitC, ip[n - 2], CTable); + HUF_FLUSHBITS_2(&bitC); + HUF_encodeSymbol(&bitC, ip[n - 3], CTable); + HUF_FLUSHBITS_1(&bitC); + HUF_encodeSymbol(&bitC, ip[n - 4], CTable); + HUF_FLUSHBITS(&bitC); + } + + return BIT_closeCStream(&bitC); +} + +size_t HUF_compress4X_usingCTable(void *dst, size_t dstSize, const void *src, size_t srcSize, const HUF_CElt *CTable) +{ + size_t const segmentSize = (srcSize + 3) / 4; /* first 3 segments */ + const BYTE *ip = (const BYTE *)src; + const BYTE *const iend = ip + srcSize; + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + BYTE *op = ostart; + + if (dstSize < 6 + 1 + 1 + 1 + 8) + return 0; /* minimum space to compress successfully */ + if (srcSize < 12) + return 0; /* no saving possible : too small input */ + op += 6; /* jumpTable */ + + { + CHECK_V_F(cSize, HUF_compress1X_usingCTable(op, oend - op, ip, segmentSize, CTable)); + if (cSize == 0) + return 0; + ZSTD_writeLE16(ostart, (U16)cSize); + op += cSize; + } + + ip += segmentSize; + { + CHECK_V_F(cSize, HUF_compress1X_usingCTable(op, oend - op, ip, segmentSize, CTable)); + if (cSize == 0) + return 0; + ZSTD_writeLE16(ostart + 2, (U16)cSize); + op += cSize; + } + + ip += segmentSize; + { + CHECK_V_F(cSize, HUF_compress1X_usingCTable(op, oend - op, ip, segmentSize, CTable)); + if (cSize == 0) + return 0; + ZSTD_writeLE16(ostart + 4, (U16)cSize); + op += cSize; + } + + ip += segmentSize; + { + CHECK_V_F(cSize, HUF_compress1X_usingCTable(op, oend - op, ip, iend - ip, CTable)); + if (cSize == 0) + return 0; + op += cSize; + } + + return op - ostart; +} + +static size_t HUF_compressCTable_internal(BYTE *const ostart, BYTE *op, BYTE *const oend, const void *src, size_t srcSize, unsigned singleStream, + const HUF_CElt *CTable) +{ + size_t const cSize = + singleStream ? HUF_compress1X_usingCTable(op, oend - op, src, srcSize, CTable) : HUF_compress4X_usingCTable(op, oend - op, src, srcSize, CTable); + if (HUF_isError(cSize)) { + return cSize; + } + if (cSize == 0) { + return 0; + } /* uncompressible */ + op += cSize; + /* check compressibility */ + if ((size_t)(op - ostart) >= srcSize - 1) { + return 0; + } + return op - ostart; +} + +/* `workSpace` must a table of at least 1024 unsigned */ +static size_t HUF_compress_internal(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, + unsigned singleStream, void *workSpace, size_t wkspSize, HUF_CElt *oldHufTable, HUF_repeat *repeat, int preferRepeat) +{ + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + BYTE *op = ostart; + + U32 *count; + size_t const countSize = sizeof(U32) * (HUF_SYMBOLVALUE_MAX + 1); + HUF_CElt *CTable; + size_t const CTableSize = sizeof(HUF_CElt) * (HUF_SYMBOLVALUE_MAX + 1); + + /* checks & inits */ + if (wkspSize < sizeof(huffNodeTable) + countSize + CTableSize) + return ERROR(GENERIC); + if (!srcSize) + return 0; /* Uncompressed (note : 1 means rle, so first byte must be correct) */ + if (!dstSize) + return 0; /* cannot fit within dst budget */ + if (srcSize > HUF_BLOCKSIZE_MAX) + return ERROR(srcSize_wrong); /* curr block size limit */ + if (huffLog > HUF_TABLELOG_MAX) + return ERROR(tableLog_tooLarge); + if (!maxSymbolValue) + maxSymbolValue = HUF_SYMBOLVALUE_MAX; + if (!huffLog) + huffLog = HUF_TABLELOG_DEFAULT; + + count = (U32 *)workSpace; + workSpace = (BYTE *)workSpace + countSize; + wkspSize -= countSize; + CTable = (HUF_CElt *)workSpace; + workSpace = (BYTE *)workSpace + CTableSize; + wkspSize -= CTableSize; + + /* Heuristic : If we don't need to check the validity of the old table use the old table for small inputs */ + if (preferRepeat && repeat && *repeat == HUF_repeat_valid) { + return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, singleStream, oldHufTable); + } + + /* Scan input and build symbol stats */ + { + CHECK_V_F(largest, FSE_count_wksp(count, &maxSymbolValue, (const BYTE *)src, srcSize, (U32 *)workSpace)); + if (largest == srcSize) { + *ostart = ((const BYTE *)src)[0]; + return 1; + } /* single symbol, rle */ + if (largest <= (srcSize >> 7) + 1) + return 0; /* Fast heuristic : not compressible enough */ + } + + /* Check validity of previous table */ + if (repeat && *repeat == HUF_repeat_check && !HUF_validateCTable(oldHufTable, count, maxSymbolValue)) { + *repeat = HUF_repeat_none; + } + /* Heuristic : use existing table for small inputs */ + if (preferRepeat && repeat && *repeat != HUF_repeat_none) { + return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, singleStream, oldHufTable); + } + + /* Build Huffman Tree */ + huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); + { + CHECK_V_F(maxBits, HUF_buildCTable_wksp(CTable, count, maxSymbolValue, huffLog, workSpace, wkspSize)); + huffLog = (U32)maxBits; + /* Zero the unused symbols so we can check it for validity */ + memset(CTable + maxSymbolValue + 1, 0, CTableSize - (maxSymbolValue + 1) * sizeof(HUF_CElt)); + } + + /* Write table description header */ + { + CHECK_V_F(hSize, HUF_writeCTable_wksp(op, dstSize, CTable, maxSymbolValue, huffLog, workSpace, wkspSize)); + /* Check if using the previous table will be beneficial */ + if (repeat && *repeat != HUF_repeat_none) { + size_t const oldSize = HUF_estimateCompressedSize(oldHufTable, count, maxSymbolValue); + size_t const newSize = HUF_estimateCompressedSize(CTable, count, maxSymbolValue); + if (oldSize <= hSize + newSize || hSize + 12 >= srcSize) { + return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, singleStream, oldHufTable); + } + } + /* Use the new table */ + if (hSize + 12ul >= srcSize) { + return 0; + } + op += hSize; + if (repeat) { + *repeat = HUF_repeat_none; + } + if (oldHufTable) { + memcpy(oldHufTable, CTable, CTableSize); + } /* Save the new table */ + } + return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, singleStream, CTable); +} + +size_t HUF_compress1X_wksp(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void *workSpace, + size_t wkspSize) +{ + return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, 1 /* single stream */, workSpace, wkspSize, NULL, NULL, 0); +} + +size_t HUF_compress1X_repeat(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void *workSpace, + size_t wkspSize, HUF_CElt *hufTable, HUF_repeat *repeat, int preferRepeat) +{ + return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, 1 /* single stream */, workSpace, wkspSize, hufTable, repeat, + preferRepeat); +} + +size_t HUF_compress4X_wksp(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void *workSpace, + size_t wkspSize) +{ + return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, 0 /* 4 streams */, workSpace, wkspSize, NULL, NULL, 0); +} + +size_t HUF_compress4X_repeat(void *dst, size_t dstSize, const void *src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void *workSpace, + size_t wkspSize, HUF_CElt *hufTable, HUF_repeat *repeat, int preferRepeat) +{ + return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, 0 /* 4 streams */, workSpace, wkspSize, hufTable, repeat, + preferRepeat); +} diff --git a/lib/zstd/huf_decompress.c b/lib/zstd/huf_decompress.c new file mode 100644 index 000000000000..6526482047dc --- /dev/null +++ b/lib/zstd/huf_decompress.c @@ -0,0 +1,960 @@ +/* + * Huffman decoder, part of New Generation Entropy library + * Copyright (C) 2013-2016, Yann Collet. + * + * BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + * + * You can contact the author at : + * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy + */ + +/* ************************************************************** +* Compiler specifics +****************************************************************/ +#define FORCE_INLINE static __always_inline + +/* ************************************************************** +* Dependencies +****************************************************************/ +#include "bitstream.h" /* BIT_* */ +#include "fse.h" /* header compression */ +#include "huf.h" +#include +#include +#include /* memcpy, memset */ + +/* ************************************************************** +* Error Management +****************************************************************/ +#define HUF_STATIC_ASSERT(c) \ + { \ + enum { HUF_static_assert = 1 / (int)(!!(c)) }; \ + } /* use only *after* variable declarations */ + +/*-***************************/ +/* generic DTableDesc */ +/*-***************************/ + +typedef struct { + BYTE maxTableLog; + BYTE tableType; + BYTE tableLog; + BYTE reserved; +} DTableDesc; + +static DTableDesc HUF_getDTableDesc(const HUF_DTable *table) +{ + DTableDesc dtd; + memcpy(&dtd, table, sizeof(dtd)); + return dtd; +} + +/*-***************************/ +/* single-symbol decoding */ +/*-***************************/ + +typedef struct { + BYTE byte; + BYTE nbBits; +} HUF_DEltX2; /* single-symbol decoding */ + +size_t HUF_readDTableX2_wksp(HUF_DTable *DTable, const void *src, size_t srcSize, void *workspace, size_t workspaceSize) +{ + U32 tableLog = 0; + U32 nbSymbols = 0; + size_t iSize; + void *const dtPtr = DTable + 1; + HUF_DEltX2 *const dt = (HUF_DEltX2 *)dtPtr; + + U32 *rankVal; + BYTE *huffWeight; + size_t spaceUsed32 = 0; + + rankVal = (U32 *)workspace + spaceUsed32; + spaceUsed32 += HUF_TABLELOG_ABSOLUTEMAX + 1; + huffWeight = (BYTE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(HUF_SYMBOLVALUE_MAX + 1, sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + HUF_STATIC_ASSERT(sizeof(DTableDesc) == sizeof(HUF_DTable)); + /* memset(huffWeight, 0, sizeof(huffWeight)); */ /* is not necessary, even though some analyzer complain ... */ + + iSize = HUF_readStats_wksp(huffWeight, HUF_SYMBOLVALUE_MAX + 1, rankVal, &nbSymbols, &tableLog, src, srcSize, workspace, workspaceSize); + if (HUF_isError(iSize)) + return iSize; + + /* Table header */ + { + DTableDesc dtd = HUF_getDTableDesc(DTable); + if (tableLog > (U32)(dtd.maxTableLog + 1)) + return ERROR(tableLog_tooLarge); /* DTable too small, Huffman tree cannot fit in */ + dtd.tableType = 0; + dtd.tableLog = (BYTE)tableLog; + memcpy(DTable, &dtd, sizeof(dtd)); + } + + /* Calculate starting value for each rank */ + { + U32 n, nextRankStart = 0; + for (n = 1; n < tableLog + 1; n++) { + U32 const curr = nextRankStart; + nextRankStart += (rankVal[n] << (n - 1)); + rankVal[n] = curr; + } + } + + /* fill DTable */ + { + U32 n; + for (n = 0; n < nbSymbols; n++) { + U32 const w = huffWeight[n]; + U32 const length = (1 << w) >> 1; + U32 u; + HUF_DEltX2 D; + D.byte = (BYTE)n; + D.nbBits = (BYTE)(tableLog + 1 - w); + for (u = rankVal[w]; u < rankVal[w] + length; u++) + dt[u] = D; + rankVal[w] += length; + } + } + + return iSize; +} + +static BYTE HUF_decodeSymbolX2(BIT_DStream_t *Dstream, const HUF_DEltX2 *dt, const U32 dtLog) +{ + size_t const val = BIT_lookBitsFast(Dstream, dtLog); /* note : dtLog >= 1 */ + BYTE const c = dt[val].byte; + BIT_skipBits(Dstream, dt[val].nbBits); + return c; +} + +#define HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) *ptr++ = HUF_decodeSymbolX2(DStreamPtr, dt, dtLog) + +#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ + if (ZSTD_64bits() || (HUF_TABLELOG_MAX <= 12)) \ + HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) + +#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ + if (ZSTD_64bits()) \ + HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) + +FORCE_INLINE size_t HUF_decodeStreamX2(BYTE *p, BIT_DStream_t *const bitDPtr, BYTE *const pEnd, const HUF_DEltX2 *const dt, const U32 dtLog) +{ + BYTE *const pStart = p; + + /* up to 4 symbols at a time */ + while ((BIT_reloadDStream(bitDPtr) == BIT_DStream_unfinished) && (p <= pEnd - 4)) { + HUF_DECODE_SYMBOLX2_2(p, bitDPtr); + HUF_DECODE_SYMBOLX2_1(p, bitDPtr); + HUF_DECODE_SYMBOLX2_2(p, bitDPtr); + HUF_DECODE_SYMBOLX2_0(p, bitDPtr); + } + + /* closer to the end */ + while ((BIT_reloadDStream(bitDPtr) == BIT_DStream_unfinished) && (p < pEnd)) + HUF_DECODE_SYMBOLX2_0(p, bitDPtr); + + /* no more data to retrieve from bitstream, hence no need to reload */ + while (p < pEnd) + HUF_DECODE_SYMBOLX2_0(p, bitDPtr); + + return pEnd - pStart; +} + +static size_t HUF_decompress1X2_usingDTable_internal(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + BYTE *op = (BYTE *)dst; + BYTE *const oend = op + dstSize; + const void *dtPtr = DTable + 1; + const HUF_DEltX2 *const dt = (const HUF_DEltX2 *)dtPtr; + BIT_DStream_t bitD; + DTableDesc const dtd = HUF_getDTableDesc(DTable); + U32 const dtLog = dtd.tableLog; + + { + size_t const errorCode = BIT_initDStream(&bitD, cSrc, cSrcSize); + if (HUF_isError(errorCode)) + return errorCode; + } + + HUF_decodeStreamX2(op, &bitD, oend, dt, dtLog); + + /* check */ + if (!BIT_endOfDStream(&bitD)) + return ERROR(corruption_detected); + + return dstSize; +} + +size_t HUF_decompress1X2_usingDTable(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc dtd = HUF_getDTableDesc(DTable); + if (dtd.tableType != 0) + return ERROR(GENERIC); + return HUF_decompress1X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable); +} + +size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable *DCtx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + const BYTE *ip = (const BYTE *)cSrc; + + size_t const hSize = HUF_readDTableX2_wksp(DCtx, cSrc, cSrcSize, workspace, workspaceSize); + if (HUF_isError(hSize)) + return hSize; + if (hSize >= cSrcSize) + return ERROR(srcSize_wrong); + ip += hSize; + cSrcSize -= hSize; + + return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx); +} + +static size_t HUF_decompress4X2_usingDTable_internal(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + /* Check */ + if (cSrcSize < 10) + return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + + { + const BYTE *const istart = (const BYTE *)cSrc; + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + const void *const dtPtr = DTable + 1; + const HUF_DEltX2 *const dt = (const HUF_DEltX2 *)dtPtr; + + /* Init */ + BIT_DStream_t bitD1; + BIT_DStream_t bitD2; + BIT_DStream_t bitD3; + BIT_DStream_t bitD4; + size_t const length1 = ZSTD_readLE16(istart); + size_t const length2 = ZSTD_readLE16(istart + 2); + size_t const length3 = ZSTD_readLE16(istart + 4); + size_t const length4 = cSrcSize - (length1 + length2 + length3 + 6); + const BYTE *const istart1 = istart + 6; /* jumpTable */ + const BYTE *const istart2 = istart1 + length1; + const BYTE *const istart3 = istart2 + length2; + const BYTE *const istart4 = istart3 + length3; + const size_t segmentSize = (dstSize + 3) / 4; + BYTE *const opStart2 = ostart + segmentSize; + BYTE *const opStart3 = opStart2 + segmentSize; + BYTE *const opStart4 = opStart3 + segmentSize; + BYTE *op1 = ostart; + BYTE *op2 = opStart2; + BYTE *op3 = opStart3; + BYTE *op4 = opStart4; + U32 endSignal; + DTableDesc const dtd = HUF_getDTableDesc(DTable); + U32 const dtLog = dtd.tableLog; + + if (length4 > cSrcSize) + return ERROR(corruption_detected); /* overflow */ + { + size_t const errorCode = BIT_initDStream(&bitD1, istart1, length1); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD2, istart2, length2); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD3, istart3, length3); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD4, istart4, length4); + if (HUF_isError(errorCode)) + return errorCode; + } + + /* 16-32 symbols per loop (4-8 symbols per stream) */ + endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4); + for (; (endSignal == BIT_DStream_unfinished) && (op4 < (oend - 7));) { + HUF_DECODE_SYMBOLX2_2(op1, &bitD1); + HUF_DECODE_SYMBOLX2_2(op2, &bitD2); + HUF_DECODE_SYMBOLX2_2(op3, &bitD3); + HUF_DECODE_SYMBOLX2_2(op4, &bitD4); + HUF_DECODE_SYMBOLX2_1(op1, &bitD1); + HUF_DECODE_SYMBOLX2_1(op2, &bitD2); + HUF_DECODE_SYMBOLX2_1(op3, &bitD3); + HUF_DECODE_SYMBOLX2_1(op4, &bitD4); + HUF_DECODE_SYMBOLX2_2(op1, &bitD1); + HUF_DECODE_SYMBOLX2_2(op2, &bitD2); + HUF_DECODE_SYMBOLX2_2(op3, &bitD3); + HUF_DECODE_SYMBOLX2_2(op4, &bitD4); + HUF_DECODE_SYMBOLX2_0(op1, &bitD1); + HUF_DECODE_SYMBOLX2_0(op2, &bitD2); + HUF_DECODE_SYMBOLX2_0(op3, &bitD3); + HUF_DECODE_SYMBOLX2_0(op4, &bitD4); + endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4); + } + + /* check corruption */ + if (op1 > opStart2) + return ERROR(corruption_detected); + if (op2 > opStart3) + return ERROR(corruption_detected); + if (op3 > opStart4) + return ERROR(corruption_detected); + /* note : op4 supposed already verified within main loop */ + + /* finish bitStreams one by one */ + HUF_decodeStreamX2(op1, &bitD1, opStart2, dt, dtLog); + HUF_decodeStreamX2(op2, &bitD2, opStart3, dt, dtLog); + HUF_decodeStreamX2(op3, &bitD3, opStart4, dt, dtLog); + HUF_decodeStreamX2(op4, &bitD4, oend, dt, dtLog); + + /* check */ + endSignal = BIT_endOfDStream(&bitD1) & BIT_endOfDStream(&bitD2) & BIT_endOfDStream(&bitD3) & BIT_endOfDStream(&bitD4); + if (!endSignal) + return ERROR(corruption_detected); + + /* decoded size */ + return dstSize; + } +} + +size_t HUF_decompress4X2_usingDTable(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc dtd = HUF_getDTableDesc(DTable); + if (dtd.tableType != 0) + return ERROR(GENERIC); + return HUF_decompress4X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable); +} + +size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + const BYTE *ip = (const BYTE *)cSrc; + + size_t const hSize = HUF_readDTableX2_wksp(dctx, cSrc, cSrcSize, workspace, workspaceSize); + if (HUF_isError(hSize)) + return hSize; + if (hSize >= cSrcSize) + return ERROR(srcSize_wrong); + ip += hSize; + cSrcSize -= hSize; + + return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx); +} + +/* *************************/ +/* double-symbols decoding */ +/* *************************/ +typedef struct { + U16 sequence; + BYTE nbBits; + BYTE length; +} HUF_DEltX4; /* double-symbols decoding */ + +typedef struct { + BYTE symbol; + BYTE weight; +} sortedSymbol_t; + +/* HUF_fillDTableX4Level2() : + * `rankValOrigin` must be a table of at least (HUF_TABLELOG_MAX + 1) U32 */ +static void HUF_fillDTableX4Level2(HUF_DEltX4 *DTable, U32 sizeLog, const U32 consumed, const U32 *rankValOrigin, const int minWeight, + const sortedSymbol_t *sortedSymbols, const U32 sortedListSize, U32 nbBitsBaseline, U16 baseSeq) +{ + HUF_DEltX4 DElt; + U32 rankVal[HUF_TABLELOG_MAX + 1]; + + /* get pre-calculated rankVal */ + memcpy(rankVal, rankValOrigin, sizeof(rankVal)); + + /* fill skipped values */ + if (minWeight > 1) { + U32 i, skipSize = rankVal[minWeight]; + ZSTD_writeLE16(&(DElt.sequence), baseSeq); + DElt.nbBits = (BYTE)(consumed); + DElt.length = 1; + for (i = 0; i < skipSize; i++) + DTable[i] = DElt; + } + + /* fill DTable */ + { + U32 s; + for (s = 0; s < sortedListSize; s++) { /* note : sortedSymbols already skipped */ + const U32 symbol = sortedSymbols[s].symbol; + const U32 weight = sortedSymbols[s].weight; + const U32 nbBits = nbBitsBaseline - weight; + const U32 length = 1 << (sizeLog - nbBits); + const U32 start = rankVal[weight]; + U32 i = start; + const U32 end = start + length; + + ZSTD_writeLE16(&(DElt.sequence), (U16)(baseSeq + (symbol << 8))); + DElt.nbBits = (BYTE)(nbBits + consumed); + DElt.length = 2; + do { + DTable[i++] = DElt; + } while (i < end); /* since length >= 1 */ + + rankVal[weight] += length; + } + } +} + +typedef U32 rankVal_t[HUF_TABLELOG_MAX][HUF_TABLELOG_MAX + 1]; +typedef U32 rankValCol_t[HUF_TABLELOG_MAX + 1]; + +static void HUF_fillDTableX4(HUF_DEltX4 *DTable, const U32 targetLog, const sortedSymbol_t *sortedList, const U32 sortedListSize, const U32 *rankStart, + rankVal_t rankValOrigin, const U32 maxWeight, const U32 nbBitsBaseline) +{ + U32 rankVal[HUF_TABLELOG_MAX + 1]; + const int scaleLog = nbBitsBaseline - targetLog; /* note : targetLog >= srcLog, hence scaleLog <= 1 */ + const U32 minBits = nbBitsBaseline - maxWeight; + U32 s; + + memcpy(rankVal, rankValOrigin, sizeof(rankVal)); + + /* fill DTable */ + for (s = 0; s < sortedListSize; s++) { + const U16 symbol = sortedList[s].symbol; + const U32 weight = sortedList[s].weight; + const U32 nbBits = nbBitsBaseline - weight; + const U32 start = rankVal[weight]; + const U32 length = 1 << (targetLog - nbBits); + + if (targetLog - nbBits >= minBits) { /* enough room for a second symbol */ + U32 sortedRank; + int minWeight = nbBits + scaleLog; + if (minWeight < 1) + minWeight = 1; + sortedRank = rankStart[minWeight]; + HUF_fillDTableX4Level2(DTable + start, targetLog - nbBits, nbBits, rankValOrigin[nbBits], minWeight, sortedList + sortedRank, + sortedListSize - sortedRank, nbBitsBaseline, symbol); + } else { + HUF_DEltX4 DElt; + ZSTD_writeLE16(&(DElt.sequence), symbol); + DElt.nbBits = (BYTE)(nbBits); + DElt.length = 1; + { + U32 const end = start + length; + U32 u; + for (u = start; u < end; u++) + DTable[u] = DElt; + } + } + rankVal[weight] += length; + } +} + +size_t HUF_readDTableX4_wksp(HUF_DTable *DTable, const void *src, size_t srcSize, void *workspace, size_t workspaceSize) +{ + U32 tableLog, maxW, sizeOfSort, nbSymbols; + DTableDesc dtd = HUF_getDTableDesc(DTable); + U32 const maxTableLog = dtd.maxTableLog; + size_t iSize; + void *dtPtr = DTable + 1; /* force compiler to avoid strict-aliasing */ + HUF_DEltX4 *const dt = (HUF_DEltX4 *)dtPtr; + U32 *rankStart; + + rankValCol_t *rankVal; + U32 *rankStats; + U32 *rankStart0; + sortedSymbol_t *sortedSymbol; + BYTE *weightList; + size_t spaceUsed32 = 0; + + HUF_STATIC_ASSERT((sizeof(rankValCol_t) & 3) == 0); + + rankVal = (rankValCol_t *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += (sizeof(rankValCol_t) * HUF_TABLELOG_MAX) >> 2; + rankStats = (U32 *)workspace + spaceUsed32; + spaceUsed32 += HUF_TABLELOG_MAX + 1; + rankStart0 = (U32 *)workspace + spaceUsed32; + spaceUsed32 += HUF_TABLELOG_MAX + 2; + sortedSymbol = (sortedSymbol_t *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(sizeof(sortedSymbol_t) * (HUF_SYMBOLVALUE_MAX + 1), sizeof(U32)) >> 2; + weightList = (BYTE *)((U32 *)workspace + spaceUsed32); + spaceUsed32 += ALIGN(HUF_SYMBOLVALUE_MAX + 1, sizeof(U32)) >> 2; + + if ((spaceUsed32 << 2) > workspaceSize) + return ERROR(tableLog_tooLarge); + workspace = (U32 *)workspace + spaceUsed32; + workspaceSize -= (spaceUsed32 << 2); + + rankStart = rankStart0 + 1; + memset(rankStats, 0, sizeof(U32) * (2 * HUF_TABLELOG_MAX + 2 + 1)); + + HUF_STATIC_ASSERT(sizeof(HUF_DEltX4) == sizeof(HUF_DTable)); /* if compiler fails here, assertion is wrong */ + if (maxTableLog > HUF_TABLELOG_MAX) + return ERROR(tableLog_tooLarge); + /* memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */ + + iSize = HUF_readStats_wksp(weightList, HUF_SYMBOLVALUE_MAX + 1, rankStats, &nbSymbols, &tableLog, src, srcSize, workspace, workspaceSize); + if (HUF_isError(iSize)) + return iSize; + + /* check result */ + if (tableLog > maxTableLog) + return ERROR(tableLog_tooLarge); /* DTable can't fit code depth */ + + /* find maxWeight */ + for (maxW = tableLog; rankStats[maxW] == 0; maxW--) { + } /* necessarily finds a solution before 0 */ + + /* Get start index of each weight */ + { + U32 w, nextRankStart = 0; + for (w = 1; w < maxW + 1; w++) { + U32 curr = nextRankStart; + nextRankStart += rankStats[w]; + rankStart[w] = curr; + } + rankStart[0] = nextRankStart; /* put all 0w symbols at the end of sorted list*/ + sizeOfSort = nextRankStart; + } + + /* sort symbols by weight */ + { + U32 s; + for (s = 0; s < nbSymbols; s++) { + U32 const w = weightList[s]; + U32 const r = rankStart[w]++; + sortedSymbol[r].symbol = (BYTE)s; + sortedSymbol[r].weight = (BYTE)w; + } + rankStart[0] = 0; /* forget 0w symbols; this is beginning of weight(1) */ + } + + /* Build rankVal */ + { + U32 *const rankVal0 = rankVal[0]; + { + int const rescale = (maxTableLog - tableLog) - 1; /* tableLog <= maxTableLog */ + U32 nextRankVal = 0; + U32 w; + for (w = 1; w < maxW + 1; w++) { + U32 curr = nextRankVal; + nextRankVal += rankStats[w] << (w + rescale); + rankVal0[w] = curr; + } + } + { + U32 const minBits = tableLog + 1 - maxW; + U32 consumed; + for (consumed = minBits; consumed < maxTableLog - minBits + 1; consumed++) { + U32 *const rankValPtr = rankVal[consumed]; + U32 w; + for (w = 1; w < maxW + 1; w++) { + rankValPtr[w] = rankVal0[w] >> consumed; + } + } + } + } + + HUF_fillDTableX4(dt, maxTableLog, sortedSymbol, sizeOfSort, rankStart0, rankVal, maxW, tableLog + 1); + + dtd.tableLog = (BYTE)maxTableLog; + dtd.tableType = 1; + memcpy(DTable, &dtd, sizeof(dtd)); + return iSize; +} + +static U32 HUF_decodeSymbolX4(void *op, BIT_DStream_t *DStream, const HUF_DEltX4 *dt, const U32 dtLog) +{ + size_t const val = BIT_lookBitsFast(DStream, dtLog); /* note : dtLog >= 1 */ + memcpy(op, dt + val, 2); + BIT_skipBits(DStream, dt[val].nbBits); + return dt[val].length; +} + +static U32 HUF_decodeLastSymbolX4(void *op, BIT_DStream_t *DStream, const HUF_DEltX4 *dt, const U32 dtLog) +{ + size_t const val = BIT_lookBitsFast(DStream, dtLog); /* note : dtLog >= 1 */ + memcpy(op, dt + val, 1); + if (dt[val].length == 1) + BIT_skipBits(DStream, dt[val].nbBits); + else { + if (DStream->bitsConsumed < (sizeof(DStream->bitContainer) * 8)) { + BIT_skipBits(DStream, dt[val].nbBits); + if (DStream->bitsConsumed > (sizeof(DStream->bitContainer) * 8)) + /* ugly hack; works only because it's the last symbol. Note : can't easily extract nbBits from just this symbol */ + DStream->bitsConsumed = (sizeof(DStream->bitContainer) * 8); + } + } + return 1; +} + +#define HUF_DECODE_SYMBOLX4_0(ptr, DStreamPtr) ptr += HUF_decodeSymbolX4(ptr, DStreamPtr, dt, dtLog) + +#define HUF_DECODE_SYMBOLX4_1(ptr, DStreamPtr) \ + if (ZSTD_64bits() || (HUF_TABLELOG_MAX <= 12)) \ + ptr += HUF_decodeSymbolX4(ptr, DStreamPtr, dt, dtLog) + +#define HUF_DECODE_SYMBOLX4_2(ptr, DStreamPtr) \ + if (ZSTD_64bits()) \ + ptr += HUF_decodeSymbolX4(ptr, DStreamPtr, dt, dtLog) + +FORCE_INLINE size_t HUF_decodeStreamX4(BYTE *p, BIT_DStream_t *bitDPtr, BYTE *const pEnd, const HUF_DEltX4 *const dt, const U32 dtLog) +{ + BYTE *const pStart = p; + + /* up to 8 symbols at a time */ + while ((BIT_reloadDStream(bitDPtr) == BIT_DStream_unfinished) & (p < pEnd - (sizeof(bitDPtr->bitContainer) - 1))) { + HUF_DECODE_SYMBOLX4_2(p, bitDPtr); + HUF_DECODE_SYMBOLX4_1(p, bitDPtr); + HUF_DECODE_SYMBOLX4_2(p, bitDPtr); + HUF_DECODE_SYMBOLX4_0(p, bitDPtr); + } + + /* closer to end : up to 2 symbols at a time */ + while ((BIT_reloadDStream(bitDPtr) == BIT_DStream_unfinished) & (p <= pEnd - 2)) + HUF_DECODE_SYMBOLX4_0(p, bitDPtr); + + while (p <= pEnd - 2) + HUF_DECODE_SYMBOLX4_0(p, bitDPtr); /* no need to reload : reached the end of DStream */ + + if (p < pEnd) + p += HUF_decodeLastSymbolX4(p, bitDPtr, dt, dtLog); + + return p - pStart; +} + +static size_t HUF_decompress1X4_usingDTable_internal(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + BIT_DStream_t bitD; + + /* Init */ + { + size_t const errorCode = BIT_initDStream(&bitD, cSrc, cSrcSize); + if (HUF_isError(errorCode)) + return errorCode; + } + + /* decode */ + { + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + const void *const dtPtr = DTable + 1; /* force compiler to not use strict-aliasing */ + const HUF_DEltX4 *const dt = (const HUF_DEltX4 *)dtPtr; + DTableDesc const dtd = HUF_getDTableDesc(DTable); + HUF_decodeStreamX4(ostart, &bitD, oend, dt, dtd.tableLog); + } + + /* check */ + if (!BIT_endOfDStream(&bitD)) + return ERROR(corruption_detected); + + /* decoded size */ + return dstSize; +} + +size_t HUF_decompress1X4_usingDTable(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc dtd = HUF_getDTableDesc(DTable); + if (dtd.tableType != 1) + return ERROR(GENERIC); + return HUF_decompress1X4_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable); +} + +size_t HUF_decompress1X4_DCtx_wksp(HUF_DTable *DCtx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + const BYTE *ip = (const BYTE *)cSrc; + + size_t const hSize = HUF_readDTableX4_wksp(DCtx, cSrc, cSrcSize, workspace, workspaceSize); + if (HUF_isError(hSize)) + return hSize; + if (hSize >= cSrcSize) + return ERROR(srcSize_wrong); + ip += hSize; + cSrcSize -= hSize; + + return HUF_decompress1X4_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx); +} + +static size_t HUF_decompress4X4_usingDTable_internal(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + if (cSrcSize < 10) + return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + + { + const BYTE *const istart = (const BYTE *)cSrc; + BYTE *const ostart = (BYTE *)dst; + BYTE *const oend = ostart + dstSize; + const void *const dtPtr = DTable + 1; + const HUF_DEltX4 *const dt = (const HUF_DEltX4 *)dtPtr; + + /* Init */ + BIT_DStream_t bitD1; + BIT_DStream_t bitD2; + BIT_DStream_t bitD3; + BIT_DStream_t bitD4; + size_t const length1 = ZSTD_readLE16(istart); + size_t const length2 = ZSTD_readLE16(istart + 2); + size_t const length3 = ZSTD_readLE16(istart + 4); + size_t const length4 = cSrcSize - (length1 + length2 + length3 + 6); + const BYTE *const istart1 = istart + 6; /* jumpTable */ + const BYTE *const istart2 = istart1 + length1; + const BYTE *const istart3 = istart2 + length2; + const BYTE *const istart4 = istart3 + length3; + size_t const segmentSize = (dstSize + 3) / 4; + BYTE *const opStart2 = ostart + segmentSize; + BYTE *const opStart3 = opStart2 + segmentSize; + BYTE *const opStart4 = opStart3 + segmentSize; + BYTE *op1 = ostart; + BYTE *op2 = opStart2; + BYTE *op3 = opStart3; + BYTE *op4 = opStart4; + U32 endSignal; + DTableDesc const dtd = HUF_getDTableDesc(DTable); + U32 const dtLog = dtd.tableLog; + + if (length4 > cSrcSize) + return ERROR(corruption_detected); /* overflow */ + { + size_t const errorCode = BIT_initDStream(&bitD1, istart1, length1); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD2, istart2, length2); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD3, istart3, length3); + if (HUF_isError(errorCode)) + return errorCode; + } + { + size_t const errorCode = BIT_initDStream(&bitD4, istart4, length4); + if (HUF_isError(errorCode)) + return errorCode; + } + + /* 16-32 symbols per loop (4-8 symbols per stream) */ + endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4); + for (; (endSignal == BIT_DStream_unfinished) & (op4 < (oend - (sizeof(bitD4.bitContainer) - 1)));) { + HUF_DECODE_SYMBOLX4_2(op1, &bitD1); + HUF_DECODE_SYMBOLX4_2(op2, &bitD2); + HUF_DECODE_SYMBOLX4_2(op3, &bitD3); + HUF_DECODE_SYMBOLX4_2(op4, &bitD4); + HUF_DECODE_SYMBOLX4_1(op1, &bitD1); + HUF_DECODE_SYMBOLX4_1(op2, &bitD2); + HUF_DECODE_SYMBOLX4_1(op3, &bitD3); + HUF_DECODE_SYMBOLX4_1(op4, &bitD4); + HUF_DECODE_SYMBOLX4_2(op1, &bitD1); + HUF_DECODE_SYMBOLX4_2(op2, &bitD2); + HUF_DECODE_SYMBOLX4_2(op3, &bitD3); + HUF_DECODE_SYMBOLX4_2(op4, &bitD4); + HUF_DECODE_SYMBOLX4_0(op1, &bitD1); + HUF_DECODE_SYMBOLX4_0(op2, &bitD2); + HUF_DECODE_SYMBOLX4_0(op3, &bitD3); + HUF_DECODE_SYMBOLX4_0(op4, &bitD4); + + endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4); + } + + /* check corruption */ + if (op1 > opStart2) + return ERROR(corruption_detected); + if (op2 > opStart3) + return ERROR(corruption_detected); + if (op3 > opStart4) + return ERROR(corruption_detected); + /* note : op4 already verified within main loop */ + + /* finish bitStreams one by one */ + HUF_decodeStreamX4(op1, &bitD1, opStart2, dt, dtLog); + HUF_decodeStreamX4(op2, &bitD2, opStart3, dt, dtLog); + HUF_decodeStreamX4(op3, &bitD3, opStart4, dt, dtLog); + HUF_decodeStreamX4(op4, &bitD4, oend, dt, dtLog); + + /* check */ + { + U32 const endCheck = BIT_endOfDStream(&bitD1) & BIT_endOfDStream(&bitD2) & BIT_endOfDStream(&bitD3) & BIT_endOfDStream(&bitD4); + if (!endCheck) + return ERROR(corruption_detected); + } + + /* decoded size */ + return dstSize; + } +} + +size_t HUF_decompress4X4_usingDTable(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc dtd = HUF_getDTableDesc(DTable); + if (dtd.tableType != 1) + return ERROR(GENERIC); + return HUF_decompress4X4_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable); +} + +size_t HUF_decompress4X4_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + const BYTE *ip = (const BYTE *)cSrc; + + size_t hSize = HUF_readDTableX4_wksp(dctx, cSrc, cSrcSize, workspace, workspaceSize); + if (HUF_isError(hSize)) + return hSize; + if (hSize >= cSrcSize) + return ERROR(srcSize_wrong); + ip += hSize; + cSrcSize -= hSize; + + return HUF_decompress4X4_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx); +} + +/* ********************************/ +/* Generic decompression selector */ +/* ********************************/ + +size_t HUF_decompress1X_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc const dtd = HUF_getDTableDesc(DTable); + return dtd.tableType ? HUF_decompress1X4_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable) + : HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable); +} + +size_t HUF_decompress4X_usingDTable(void *dst, size_t maxDstSize, const void *cSrc, size_t cSrcSize, const HUF_DTable *DTable) +{ + DTableDesc const dtd = HUF_getDTableDesc(DTable); + return dtd.tableType ? HUF_decompress4X4_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable) + : HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable); +} + +typedef struct { + U32 tableTime; + U32 decode256Time; +} algo_time_t; +static const algo_time_t algoTime[16 /* Quantization */][3 /* single, double, quad */] = { + /* single, double, quad */ + {{0, 0}, {1, 1}, {2, 2}}, /* Q==0 : impossible */ + {{0, 0}, {1, 1}, {2, 2}}, /* Q==1 : impossible */ + {{38, 130}, {1313, 74}, {2151, 38}}, /* Q == 2 : 12-18% */ + {{448, 128}, {1353, 74}, {2238, 41}}, /* Q == 3 : 18-25% */ + {{556, 128}, {1353, 74}, {2238, 47}}, /* Q == 4 : 25-32% */ + {{714, 128}, {1418, 74}, {2436, 53}}, /* Q == 5 : 32-38% */ + {{883, 128}, {1437, 74}, {2464, 61}}, /* Q == 6 : 38-44% */ + {{897, 128}, {1515, 75}, {2622, 68}}, /* Q == 7 : 44-50% */ + {{926, 128}, {1613, 75}, {2730, 75}}, /* Q == 8 : 50-56% */ + {{947, 128}, {1729, 77}, {3359, 77}}, /* Q == 9 : 56-62% */ + {{1107, 128}, {2083, 81}, {4006, 84}}, /* Q ==10 : 62-69% */ + {{1177, 128}, {2379, 87}, {4785, 88}}, /* Q ==11 : 69-75% */ + {{1242, 128}, {2415, 93}, {5155, 84}}, /* Q ==12 : 75-81% */ + {{1349, 128}, {2644, 106}, {5260, 106}}, /* Q ==13 : 81-87% */ + {{1455, 128}, {2422, 124}, {4174, 124}}, /* Q ==14 : 87-93% */ + {{722, 128}, {1891, 145}, {1936, 146}}, /* Q ==15 : 93-99% */ +}; + +/** HUF_selectDecoder() : +* Tells which decoder is likely to decode faster, +* based on a set of pre-determined metrics. +* @return : 0==HUF_decompress4X2, 1==HUF_decompress4X4 . +* Assumption : 0 < cSrcSize < dstSize <= 128 KB */ +U32 HUF_selectDecoder(size_t dstSize, size_t cSrcSize) +{ + /* decoder timing evaluation */ + U32 const Q = (U32)(cSrcSize * 16 / dstSize); /* Q < 16 since dstSize > cSrcSize */ + U32 const D256 = (U32)(dstSize >> 8); + U32 const DTime0 = algoTime[Q][0].tableTime + (algoTime[Q][0].decode256Time * D256); + U32 DTime1 = algoTime[Q][1].tableTime + (algoTime[Q][1].decode256Time * D256); + DTime1 += DTime1 >> 3; /* advantage to algorithm using less memory, for cache eviction */ + + return DTime1 < DTime0; +} + +typedef size_t (*decompressionAlgo)(void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize); + +size_t HUF_decompress4X_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + /* validation checks */ + if (dstSize == 0) + return ERROR(dstSize_tooSmall); + if (cSrcSize > dstSize) + return ERROR(corruption_detected); /* invalid */ + if (cSrcSize == dstSize) { + memcpy(dst, cSrc, dstSize); + return dstSize; + } /* not compressed */ + if (cSrcSize == 1) { + memset(dst, *(const BYTE *)cSrc, dstSize); + return dstSize; + } /* RLE */ + + { + U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); + return algoNb ? HUF_decompress4X4_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize) + : HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize); + } +} + +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + /* validation checks */ + if (dstSize == 0) + return ERROR(dstSize_tooSmall); + if ((cSrcSize >= dstSize) || (cSrcSize <= 1)) + return ERROR(corruption_detected); /* invalid */ + + { + U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); + return algoNb ? HUF_decompress4X4_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize) + : HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize); + } +} + +size_t HUF_decompress1X_DCtx_wksp(HUF_DTable *dctx, void *dst, size_t dstSize, const void *cSrc, size_t cSrcSize, void *workspace, size_t workspaceSize) +{ + /* validation checks */ + if (dstSize == 0) + return ERROR(dstSize_tooSmall); + if (cSrcSize > dstSize) + return ERROR(corruption_detected); /* invalid */ + if (cSrcSize == dstSize) { + memcpy(dst, cSrc, dstSize); + return dstSize; + } /* not compressed */ + if (cSrcSize == 1) { + memset(dst, *(const BYTE *)cSrc, dstSize); + return dstSize; + } /* RLE */ + + { + U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); + return algoNb ? HUF_decompress1X4_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize) + : HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workspace, workspaceSize); + } +} diff --git a/lib/zstd/mem.h b/lib/zstd/mem.h new file mode 100644 index 000000000000..3a0f34c8706c --- /dev/null +++ b/lib/zstd/mem.h @@ -0,0 +1,151 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +#ifndef MEM_H_MODULE +#define MEM_H_MODULE + +/*-**************************************** +* Dependencies +******************************************/ +#include +#include /* memcpy */ +#include /* size_t, ptrdiff_t */ + +/*-**************************************** +* Compiler specifics +******************************************/ +#define ZSTD_STATIC static __inline __attribute__((unused)) + +/*-************************************************************** +* Basic Types +*****************************************************************/ +typedef uint8_t BYTE; +typedef uint16_t U16; +typedef int16_t S16; +typedef uint32_t U32; +typedef int32_t S32; +typedef uint64_t U64; +typedef int64_t S64; +typedef ptrdiff_t iPtrDiff; +typedef uintptr_t uPtrDiff; + +/*-************************************************************** +* Memory I/O +*****************************************************************/ +ZSTD_STATIC unsigned ZSTD_32bits(void) { return sizeof(size_t) == 4; } +ZSTD_STATIC unsigned ZSTD_64bits(void) { return sizeof(size_t) == 8; } + +#if defined(__LITTLE_ENDIAN) +#define ZSTD_LITTLE_ENDIAN 1 +#else +#define ZSTD_LITTLE_ENDIAN 0 +#endif + +ZSTD_STATIC unsigned ZSTD_isLittleEndian(void) { return ZSTD_LITTLE_ENDIAN; } + +ZSTD_STATIC U16 ZSTD_read16(const void *memPtr) { return get_unaligned((const U16 *)memPtr); } + +ZSTD_STATIC U32 ZSTD_read32(const void *memPtr) { return get_unaligned((const U32 *)memPtr); } + +ZSTD_STATIC U64 ZSTD_read64(const void *memPtr) { return get_unaligned((const U64 *)memPtr); } + +ZSTD_STATIC size_t ZSTD_readST(const void *memPtr) { return get_unaligned((const size_t *)memPtr); } + +ZSTD_STATIC void ZSTD_write16(void *memPtr, U16 value) { put_unaligned(value, (U16 *)memPtr); } + +ZSTD_STATIC void ZSTD_write32(void *memPtr, U32 value) { put_unaligned(value, (U32 *)memPtr); } + +ZSTD_STATIC void ZSTD_write64(void *memPtr, U64 value) { put_unaligned(value, (U64 *)memPtr); } + +/*=== Little endian r/w ===*/ + +ZSTD_STATIC U16 ZSTD_readLE16(const void *memPtr) { return get_unaligned_le16(memPtr); } + +ZSTD_STATIC void ZSTD_writeLE16(void *memPtr, U16 val) { put_unaligned_le16(val, memPtr); } + +ZSTD_STATIC U32 ZSTD_readLE24(const void *memPtr) { return ZSTD_readLE16(memPtr) + (((const BYTE *)memPtr)[2] << 16); } + +ZSTD_STATIC void ZSTD_writeLE24(void *memPtr, U32 val) +{ + ZSTD_writeLE16(memPtr, (U16)val); + ((BYTE *)memPtr)[2] = (BYTE)(val >> 16); +} + +ZSTD_STATIC U32 ZSTD_readLE32(const void *memPtr) { return get_unaligned_le32(memPtr); } + +ZSTD_STATIC void ZSTD_writeLE32(void *memPtr, U32 val32) { put_unaligned_le32(val32, memPtr); } + +ZSTD_STATIC U64 ZSTD_readLE64(const void *memPtr) { return get_unaligned_le64(memPtr); } + +ZSTD_STATIC void ZSTD_writeLE64(void *memPtr, U64 val64) { put_unaligned_le64(val64, memPtr); } + +ZSTD_STATIC size_t ZSTD_readLEST(const void *memPtr) +{ + if (ZSTD_32bits()) + return (size_t)ZSTD_readLE32(memPtr); + else + return (size_t)ZSTD_readLE64(memPtr); +} + +ZSTD_STATIC void ZSTD_writeLEST(void *memPtr, size_t val) +{ + if (ZSTD_32bits()) + ZSTD_writeLE32(memPtr, (U32)val); + else + ZSTD_writeLE64(memPtr, (U64)val); +} + +/*=== Big endian r/w ===*/ + +ZSTD_STATIC U32 ZSTD_readBE32(const void *memPtr) { return get_unaligned_be32(memPtr); } + +ZSTD_STATIC void ZSTD_writeBE32(void *memPtr, U32 val32) { put_unaligned_be32(val32, memPtr); } + +ZSTD_STATIC U64 ZSTD_readBE64(const void *memPtr) { return get_unaligned_be64(memPtr); } + +ZSTD_STATIC void ZSTD_writeBE64(void *memPtr, U64 val64) { put_unaligned_be64(val64, memPtr); } + +ZSTD_STATIC size_t ZSTD_readBEST(const void *memPtr) +{ + if (ZSTD_32bits()) + return (size_t)ZSTD_readBE32(memPtr); + else + return (size_t)ZSTD_readBE64(memPtr); +} + +ZSTD_STATIC void ZSTD_writeBEST(void *memPtr, size_t val) +{ + if (ZSTD_32bits()) + ZSTD_writeBE32(memPtr, (U32)val); + else + ZSTD_writeBE64(memPtr, (U64)val); +} + +/* function safe only for comparisons */ +ZSTD_STATIC U32 ZSTD_readMINMATCH(const void *memPtr, U32 length) +{ + switch (length) { + default: + case 4: return ZSTD_read32(memPtr); + case 3: + if (ZSTD_isLittleEndian()) + return ZSTD_read32(memPtr) << 8; + else + return ZSTD_read32(memPtr) >> 8; + } +} + +#endif /* MEM_H_MODULE */ diff --git a/lib/zstd/zstd_common.c b/lib/zstd/zstd_common.c new file mode 100644 index 000000000000..a282624ee155 --- /dev/null +++ b/lib/zstd/zstd_common.c @@ -0,0 +1,75 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +/*-************************************* +* Dependencies +***************************************/ +#include "error_private.h" +#include "zstd_internal.h" /* declaration of ZSTD_isError, ZSTD_getErrorName, ZSTD_getErrorCode, ZSTD_getErrorString, ZSTD_versionNumber */ +#include + +/*=************************************************************** +* Custom allocator +****************************************************************/ + +#define stack_push(stack, size) \ + ({ \ + void *const ptr = ZSTD_PTR_ALIGN((stack)->ptr); \ + (stack)->ptr = (char *)ptr + (size); \ + (stack)->ptr <= (stack)->end ? ptr : NULL; \ + }) + +ZSTD_customMem ZSTD_initStack(void *workspace, size_t workspaceSize) +{ + ZSTD_customMem stackMem = {ZSTD_stackAlloc, ZSTD_stackFree, workspace}; + ZSTD_stack *stack = (ZSTD_stack *)workspace; + /* Verify preconditions */ + if (!workspace || workspaceSize < sizeof(ZSTD_stack) || workspace != ZSTD_PTR_ALIGN(workspace)) { + ZSTD_customMem error = {NULL, NULL, NULL}; + return error; + } + /* Initialize the stack */ + stack->ptr = workspace; + stack->end = (char *)workspace + workspaceSize; + stack_push(stack, sizeof(ZSTD_stack)); + return stackMem; +} + +void *ZSTD_stackAllocAll(void *opaque, size_t *size) +{ + ZSTD_stack *stack = (ZSTD_stack *)opaque; + *size = (BYTE const *)stack->end - (BYTE *)ZSTD_PTR_ALIGN(stack->ptr); + return stack_push(stack, *size); +} + +void *ZSTD_stackAlloc(void *opaque, size_t size) +{ + ZSTD_stack *stack = (ZSTD_stack *)opaque; + return stack_push(stack, size); +} +void ZSTD_stackFree(void *opaque, void *address) +{ + (void)opaque; + (void)address; +} + +void *ZSTD_malloc(size_t size, ZSTD_customMem customMem) { return customMem.customAlloc(customMem.opaque, size); } + +void ZSTD_free(void *ptr, ZSTD_customMem customMem) +{ + if (ptr != NULL) + customMem.customFree(customMem.opaque, ptr); +} diff --git a/lib/zstd/zstd_internal.h b/lib/zstd/zstd_internal.h new file mode 100644 index 000000000000..1a79fab9e13a --- /dev/null +++ b/lib/zstd/zstd_internal.h @@ -0,0 +1,263 @@ +/** + * Copyright (c) 2016-present, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +#ifndef ZSTD_CCOMMON_H_MODULE +#define ZSTD_CCOMMON_H_MODULE + +/*-******************************************************* +* Compiler specifics +*********************************************************/ +#define FORCE_INLINE static __always_inline +#define FORCE_NOINLINE static noinline + +/*-************************************* +* Dependencies +***************************************/ +#include "error_private.h" +#include "mem.h" +#include +#include +#include +#include + +/*-************************************* +* shared macros +***************************************/ +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define CHECK_F(f) \ + { \ + size_t const errcod = f; \ + if (ERR_isError(errcod)) \ + return errcod; \ + } /* check and Forward error code */ +#define CHECK_E(f, e) \ + { \ + size_t const errcod = f; \ + if (ERR_isError(errcod)) \ + return ERROR(e); \ + } /* check and send Error code */ +#define ZSTD_STATIC_ASSERT(c) \ + { \ + enum { ZSTD_static_assert = 1 / (int)(!!(c)) }; \ + } + +/*-************************************* +* Common constants +***************************************/ +#define ZSTD_OPT_NUM (1 << 12) +#define ZSTD_DICT_MAGIC 0xEC30A437 /* v0.7+ */ + +#define ZSTD_REP_NUM 3 /* number of repcodes */ +#define ZSTD_REP_CHECK (ZSTD_REP_NUM) /* number of repcodes to check by the optimal parser */ +#define ZSTD_REP_MOVE (ZSTD_REP_NUM - 1) +#define ZSTD_REP_MOVE_OPT (ZSTD_REP_NUM) +static const U32 repStartValue[ZSTD_REP_NUM] = {1, 4, 8}; + +#define KB *(1 << 10) +#define MB *(1 << 20) +#define GB *(1U << 30) + +#define BIT7 128 +#define BIT6 64 +#define BIT5 32 +#define BIT4 16 +#define BIT1 2 +#define BIT0 1 + +#define ZSTD_WINDOWLOG_ABSOLUTEMIN 10 +static const size_t ZSTD_fcs_fieldSize[4] = {0, 2, 4, 8}; +static const size_t ZSTD_did_fieldSize[4] = {0, 1, 2, 4}; + +#define ZSTD_BLOCKHEADERSIZE 3 /* C standard doesn't allow `static const` variable to be init using another `static const` variable */ +static const size_t ZSTD_blockHeaderSize = ZSTD_BLOCKHEADERSIZE; +typedef enum { bt_raw, bt_rle, bt_compressed, bt_reserved } blockType_e; + +#define MIN_SEQUENCES_SIZE 1 /* nbSeq==0 */ +#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */ + MIN_SEQUENCES_SIZE /* nbSeq==0 */) /* for a non-null block */ + +#define HufLog 12 +typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingType_e; + +#define LONGNBSEQ 0x7F00 + +#define MINMATCH 3 +#define EQUAL_READ32 4 + +#define Litbits 8 +#define MaxLit ((1 << Litbits) - 1) +#define MaxML 52 +#define MaxLL 35 +#define MaxOff 28 +#define MaxSeq MAX(MaxLL, MaxML) /* Assumption : MaxOff < MaxLL,MaxML */ +#define MLFSELog 9 +#define LLFSELog 9 +#define OffFSELog 8 + +static const U32 LL_bits[MaxLL + 1] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; +static const S16 LL_defaultNorm[MaxLL + 1] = {4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; +#define LL_DEFAULTNORMLOG 6 /* for static allocation */ +static const U32 LL_defaultNormLog = LL_DEFAULTNORMLOG; + +static const U32 ML_bits[MaxML + 1] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; +static const S16 ML_defaultNorm[MaxML + 1] = {1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; +#define ML_DEFAULTNORMLOG 6 /* for static allocation */ +static const U32 ML_defaultNormLog = ML_DEFAULTNORMLOG; + +static const S16 OF_defaultNorm[MaxOff + 1] = {1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; +#define OF_DEFAULTNORMLOG 5 /* for static allocation */ +static const U32 OF_defaultNormLog = OF_DEFAULTNORMLOG; + +/*-******************************************* +* Shared functions to include for inlining +*********************************************/ +ZSTD_STATIC void ZSTD_copy8(void *dst, const void *src) { + memcpy(dst, src, 8); +} +/*! ZSTD_wildcopy() : +* custom version of memcpy(), can copy up to 7 bytes too many (8 bytes if length==0) */ +#define WILDCOPY_OVERLENGTH 8 +ZSTD_STATIC void ZSTD_wildcopy(void *dst, const void *src, ptrdiff_t length) +{ + const BYTE* ip = (const BYTE*)src; + BYTE* op = (BYTE*)dst; + BYTE* const oend = op + length; + /* Work around https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81388. + * Avoid the bad case where the loop only runs once by handling the + * special case separately. This doesn't trigger the bug because it + * doesn't involve pointer/integer overflow. + */ + if (length <= 8) + return ZSTD_copy8(dst, src); + do { + ZSTD_copy8(op, ip); + op += 8; + ip += 8; + } while (op < oend); +} + +/*-******************************************* +* Private interfaces +*********************************************/ +typedef struct ZSTD_stats_s ZSTD_stats_t; + +typedef struct { + U32 off; + U32 len; +} ZSTD_match_t; + +typedef struct { + U32 price; + U32 off; + U32 mlen; + U32 litlen; + U32 rep[ZSTD_REP_NUM]; +} ZSTD_optimal_t; + +typedef struct seqDef_s { + U32 offset; + U16 litLength; + U16 matchLength; +} seqDef; + +typedef struct { + seqDef *sequencesStart; + seqDef *sequences; + BYTE *litStart; + BYTE *lit; + BYTE *llCode; + BYTE *mlCode; + BYTE *ofCode; + U32 longLengthID; /* 0 == no longLength; 1 == Lit.longLength; 2 == Match.longLength; */ + U32 longLengthPos; + /* opt */ + ZSTD_optimal_t *priceTable; + ZSTD_match_t *matchTable; + U32 *matchLengthFreq; + U32 *litLengthFreq; + U32 *litFreq; + U32 *offCodeFreq; + U32 matchLengthSum; + U32 matchSum; + U32 litLengthSum; + U32 litSum; + U32 offCodeSum; + U32 log2matchLengthSum; + U32 log2matchSum; + U32 log2litLengthSum; + U32 log2litSum; + U32 log2offCodeSum; + U32 factor; + U32 staticPrices; + U32 cachedPrice; + U32 cachedLitLength; + const BYTE *cachedLiterals; +} seqStore_t; + +const seqStore_t *ZSTD_getSeqStore(const ZSTD_CCtx *ctx); +void ZSTD_seqToCodes(const seqStore_t *seqStorePtr); +int ZSTD_isSkipFrame(ZSTD_DCtx *dctx); + +/*= Custom memory allocation functions */ +typedef void *(*ZSTD_allocFunction)(void *opaque, size_t size); +typedef void (*ZSTD_freeFunction)(void *opaque, void *address); +typedef struct { + ZSTD_allocFunction customAlloc; + ZSTD_freeFunction customFree; + void *opaque; +} ZSTD_customMem; + +void *ZSTD_malloc(size_t size, ZSTD_customMem customMem); +void ZSTD_free(void *ptr, ZSTD_customMem customMem); + +/*====== stack allocation ======*/ + +typedef struct { + void *ptr; + const void *end; +} ZSTD_stack; + +#define ZSTD_ALIGN(x) ALIGN(x, sizeof(size_t)) +#define ZSTD_PTR_ALIGN(p) PTR_ALIGN(p, sizeof(size_t)) + +ZSTD_customMem ZSTD_initStack(void *workspace, size_t workspaceSize); + +void *ZSTD_stackAllocAll(void *opaque, size_t *size); +void *ZSTD_stackAlloc(void *opaque, size_t size); +void ZSTD_stackFree(void *opaque, void *address); + +/*====== common function ======*/ + +ZSTD_STATIC U32 ZSTD_highbit32(U32 val) { return 31 - __builtin_clz(val); } + +/* hidden functions */ + +/* ZSTD_invalidateRepCodes() : + * ensures next compression will not use repcodes from previous block. + * Note : only works with regular variant; + * do not use with extDict variant ! */ +void ZSTD_invalidateRepCodes(ZSTD_CCtx *cctx); + +size_t ZSTD_freeCCtx(ZSTD_CCtx *cctx); +size_t ZSTD_freeDCtx(ZSTD_DCtx *dctx); +size_t ZSTD_freeCDict(ZSTD_CDict *cdict); +size_t ZSTD_freeDDict(ZSTD_DDict *cdict); +size_t ZSTD_freeCStream(ZSTD_CStream *zcs); +size_t ZSTD_freeDStream(ZSTD_DStream *zds); + +#endif /* ZSTD_CCOMMON_H_MODULE */ diff --git a/lib/zstd/zstd_opt.h b/lib/zstd/zstd_opt.h new file mode 100644 index 000000000000..55e1b4cba808 --- /dev/null +++ b/lib/zstd/zstd_opt.h @@ -0,0 +1,1014 @@ +/** + * Copyright (c) 2016-present, Przemyslaw Skibinski, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of https://github.com/facebook/zstd. + * An additional grant of patent rights can be found in the PATENTS file in the + * same directory. + * + * This program is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License version 2 as published by the + * Free Software Foundation. This program is dual-licensed; you may select + * either version 2 of the GNU General Public License ("GPL") or BSD license + * ("BSD"). + */ + +/* Note : this file is intended to be included within zstd_compress.c */ + +#ifndef ZSTD_OPT_H_91842398743 +#define ZSTD_OPT_H_91842398743 + +#define ZSTD_LITFREQ_ADD 2 +#define ZSTD_FREQ_DIV 4 +#define ZSTD_MAX_PRICE (1 << 30) + +/*-************************************* +* Price functions for optimal parser +***************************************/ +FORCE_INLINE void ZSTD_setLog2Prices(seqStore_t *ssPtr) +{ + ssPtr->log2matchLengthSum = ZSTD_highbit32(ssPtr->matchLengthSum + 1); + ssPtr->log2litLengthSum = ZSTD_highbit32(ssPtr->litLengthSum + 1); + ssPtr->log2litSum = ZSTD_highbit32(ssPtr->litSum + 1); + ssPtr->log2offCodeSum = ZSTD_highbit32(ssPtr->offCodeSum + 1); + ssPtr->factor = 1 + ((ssPtr->litSum >> 5) / ssPtr->litLengthSum) + ((ssPtr->litSum << 1) / (ssPtr->litSum + ssPtr->matchSum)); +} + +ZSTD_STATIC void ZSTD_rescaleFreqs(seqStore_t *ssPtr, const BYTE *src, size_t srcSize) +{ + unsigned u; + + ssPtr->cachedLiterals = NULL; + ssPtr->cachedPrice = ssPtr->cachedLitLength = 0; + ssPtr->staticPrices = 0; + + if (ssPtr->litLengthSum == 0) { + if (srcSize <= 1024) + ssPtr->staticPrices = 1; + + for (u = 0; u <= MaxLit; u++) + ssPtr->litFreq[u] = 0; + for (u = 0; u < srcSize; u++) + ssPtr->litFreq[src[u]]++; + + ssPtr->litSum = 0; + ssPtr->litLengthSum = MaxLL + 1; + ssPtr->matchLengthSum = MaxML + 1; + ssPtr->offCodeSum = (MaxOff + 1); + ssPtr->matchSum = (ZSTD_LITFREQ_ADD << Litbits); + + for (u = 0; u <= MaxLit; u++) { + ssPtr->litFreq[u] = 1 + (ssPtr->litFreq[u] >> ZSTD_FREQ_DIV); + ssPtr->litSum += ssPtr->litFreq[u]; + } + for (u = 0; u <= MaxLL; u++) + ssPtr->litLengthFreq[u] = 1; + for (u = 0; u <= MaxML; u++) + ssPtr->matchLengthFreq[u] = 1; + for (u = 0; u <= MaxOff; u++) + ssPtr->offCodeFreq[u] = 1; + } else { + ssPtr->matchLengthSum = 0; + ssPtr->litLengthSum = 0; + ssPtr->offCodeSum = 0; + ssPtr->matchSum = 0; + ssPtr->litSum = 0; + + for (u = 0; u <= MaxLit; u++) { + ssPtr->litFreq[u] = 1 + (ssPtr->litFreq[u] >> (ZSTD_FREQ_DIV + 1)); + ssPtr->litSum += ssPtr->litFreq[u]; + } + for (u = 0; u <= MaxLL; u++) { + ssPtr->litLengthFreq[u] = 1 + (ssPtr->litLengthFreq[u] >> (ZSTD_FREQ_DIV + 1)); + ssPtr->litLengthSum += ssPtr->litLengthFreq[u]; + } + for (u = 0; u <= MaxML; u++) { + ssPtr->matchLengthFreq[u] = 1 + (ssPtr->matchLengthFreq[u] >> ZSTD_FREQ_DIV); + ssPtr->matchLengthSum += ssPtr->matchLengthFreq[u]; + ssPtr->matchSum += ssPtr->matchLengthFreq[u] * (u + 3); + } + ssPtr->matchSum *= ZSTD_LITFREQ_ADD; + for (u = 0; u <= MaxOff; u++) { + ssPtr->offCodeFreq[u] = 1 + (ssPtr->offCodeFreq[u] >> ZSTD_FREQ_DIV); + ssPtr->offCodeSum += ssPtr->offCodeFreq[u]; + } + } + + ZSTD_setLog2Prices(ssPtr); +} + +FORCE_INLINE U32 ZSTD_getLiteralPrice(seqStore_t *ssPtr, U32 litLength, const BYTE *literals) +{ + U32 price, u; + + if (ssPtr->staticPrices) + return ZSTD_highbit32((U32)litLength + 1) + (litLength * 6); + + if (litLength == 0) + return ssPtr->log2litLengthSum - ZSTD_highbit32(ssPtr->litLengthFreq[0] + 1); + + /* literals */ + if (ssPtr->cachedLiterals == literals) { + U32 const additional = litLength - ssPtr->cachedLitLength; + const BYTE *literals2 = ssPtr->cachedLiterals + ssPtr->cachedLitLength; + price = ssPtr->cachedPrice + additional * ssPtr->log2litSum; + for (u = 0; u < additional; u++) + price -= ZSTD_highbit32(ssPtr->litFreq[literals2[u]] + 1); + ssPtr->cachedPrice = price; + ssPtr->cachedLitLength = litLength; + } else { + price = litLength * ssPtr->log2litSum; + for (u = 0; u < litLength; u++) + price -= ZSTD_highbit32(ssPtr->litFreq[literals[u]] + 1); + + if (litLength >= 12) { + ssPtr->cachedLiterals = literals; + ssPtr->cachedPrice = price; + ssPtr->cachedLitLength = litLength; + } + } + + /* literal Length */ + { + const BYTE LL_deltaCode = 19; + const BYTE llCode = (litLength > 63) ? (BYTE)ZSTD_highbit32(litLength) + LL_deltaCode : LL_Code[litLength]; + price += LL_bits[llCode] + ssPtr->log2litLengthSum - ZSTD_highbit32(ssPtr->litLengthFreq[llCode] + 1); + } + + return price; +} + +FORCE_INLINE U32 ZSTD_getPrice(seqStore_t *seqStorePtr, U32 litLength, const BYTE *literals, U32 offset, U32 matchLength, const int ultra) +{ + /* offset */ + U32 price; + BYTE const offCode = (BYTE)ZSTD_highbit32(offset + 1); + + if (seqStorePtr->staticPrices) + return ZSTD_getLiteralPrice(seqStorePtr, litLength, literals) + ZSTD_highbit32((U32)matchLength + 1) + 16 + offCode; + + price = offCode + seqStorePtr->log2offCodeSum - ZSTD_highbit32(seqStorePtr->offCodeFreq[offCode] + 1); + if (!ultra && offCode >= 20) + price += (offCode - 19) * 2; + + /* match Length */ + { + const BYTE ML_deltaCode = 36; + const BYTE mlCode = (matchLength > 127) ? (BYTE)ZSTD_highbit32(matchLength) + ML_deltaCode : ML_Code[matchLength]; + price += ML_bits[mlCode] + seqStorePtr->log2matchLengthSum - ZSTD_highbit32(seqStorePtr->matchLengthFreq[mlCode] + 1); + } + + return price + ZSTD_getLiteralPrice(seqStorePtr, litLength, literals) + seqStorePtr->factor; +} + +ZSTD_STATIC void ZSTD_updatePrice(seqStore_t *seqStorePtr, U32 litLength, const BYTE *literals, U32 offset, U32 matchLength) +{ + U32 u; + + /* literals */ + seqStorePtr->litSum += litLength * ZSTD_LITFREQ_ADD; + for (u = 0; u < litLength; u++) + seqStorePtr->litFreq[literals[u]] += ZSTD_LITFREQ_ADD; + + /* literal Length */ + { + const BYTE LL_deltaCode = 19; + const BYTE llCode = (litLength > 63) ? (BYTE)ZSTD_highbit32(litLength) + LL_deltaCode : LL_Code[litLength]; + seqStorePtr->litLengthFreq[llCode]++; + seqStorePtr->litLengthSum++; + } + + /* match offset */ + { + BYTE const offCode = (BYTE)ZSTD_highbit32(offset + 1); + seqStorePtr->offCodeSum++; + seqStorePtr->offCodeFreq[offCode]++; + } + + /* match Length */ + { + const BYTE ML_deltaCode = 36; + const BYTE mlCode = (matchLength > 127) ? (BYTE)ZSTD_highbit32(matchLength) + ML_deltaCode : ML_Code[matchLength]; + seqStorePtr->matchLengthFreq[mlCode]++; + seqStorePtr->matchLengthSum++; + } + + ZSTD_setLog2Prices(seqStorePtr); +} + +#define SET_PRICE(pos, mlen_, offset_, litlen_, price_) \ + { \ + while (last_pos < pos) { \ + opt[last_pos + 1].price = ZSTD_MAX_PRICE; \ + last_pos++; \ + } \ + opt[pos].mlen = mlen_; \ + opt[pos].off = offset_; \ + opt[pos].litlen = litlen_; \ + opt[pos].price = price_; \ + } + +/* Update hashTable3 up to ip (excluded) + Assumption : always within prefix (i.e. not within extDict) */ +FORCE_INLINE +U32 ZSTD_insertAndFindFirstIndexHash3(ZSTD_CCtx *zc, const BYTE *ip) +{ + U32 *const hashTable3 = zc->hashTable3; + U32 const hashLog3 = zc->hashLog3; + const BYTE *const base = zc->base; + U32 idx = zc->nextToUpdate3; + const U32 target = zc->nextToUpdate3 = (U32)(ip - base); + const size_t hash3 = ZSTD_hash3Ptr(ip, hashLog3); + + while (idx < target) { + hashTable3[ZSTD_hash3Ptr(base + idx, hashLog3)] = idx; + idx++; + } + + return hashTable3[hash3]; +} + +/*-************************************* +* Binary Tree search +***************************************/ +static U32 ZSTD_insertBtAndGetAllMatches(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iLimit, U32 nbCompares, const U32 mls, U32 extDict, + ZSTD_match_t *matches, const U32 minMatchLen) +{ + const BYTE *const base = zc->base; + const U32 curr = (U32)(ip - base); + const U32 hashLog = zc->params.cParams.hashLog; + const size_t h = ZSTD_hashPtr(ip, hashLog, mls); + U32 *const hashTable = zc->hashTable; + U32 matchIndex = hashTable[h]; + U32 *const bt = zc->chainTable; + const U32 btLog = zc->params.cParams.chainLog - 1; + const U32 btMask = (1U << btLog) - 1; + size_t commonLengthSmaller = 0, commonLengthLarger = 0; + const BYTE *const dictBase = zc->dictBase; + const U32 dictLimit = zc->dictLimit; + const BYTE *const dictEnd = dictBase + dictLimit; + const BYTE *const prefixStart = base + dictLimit; + const U32 btLow = btMask >= curr ? 0 : curr - btMask; + const U32 windowLow = zc->lowLimit; + U32 *smallerPtr = bt + 2 * (curr & btMask); + U32 *largerPtr = bt + 2 * (curr & btMask) + 1; + U32 matchEndIdx = curr + 8; + U32 dummy32; /* to be nullified at the end */ + U32 mnum = 0; + + const U32 minMatch = (mls == 3) ? 3 : 4; + size_t bestLength = minMatchLen - 1; + + if (minMatch == 3) { /* HC3 match finder */ + U32 const matchIndex3 = ZSTD_insertAndFindFirstIndexHash3(zc, ip); + if (matchIndex3 > windowLow && (curr - matchIndex3 < (1 << 18))) { + const BYTE *match; + size_t currMl = 0; + if ((!extDict) || matchIndex3 >= dictLimit) { + match = base + matchIndex3; + if (match[bestLength] == ip[bestLength]) + currMl = ZSTD_count(ip, match, iLimit); + } else { + match = dictBase + matchIndex3; + if (ZSTD_readMINMATCH(match, MINMATCH) == + ZSTD_readMINMATCH(ip, MINMATCH)) /* assumption : matchIndex3 <= dictLimit-4 (by table construction) */ + currMl = ZSTD_count_2segments(ip + MINMATCH, match + MINMATCH, iLimit, dictEnd, prefixStart) + MINMATCH; + } + + /* save best solution */ + if (currMl > bestLength) { + bestLength = currMl; + matches[mnum].off = ZSTD_REP_MOVE_OPT + curr - matchIndex3; + matches[mnum].len = (U32)currMl; + mnum++; + if (currMl > ZSTD_OPT_NUM) + goto update; + if (ip + currMl == iLimit) + goto update; /* best possible, and avoid read overflow*/ + } + } + } + + hashTable[h] = curr; /* Update Hash Table */ + + while (nbCompares-- && (matchIndex > windowLow)) { + U32 *nextPtr = bt + 2 * (matchIndex & btMask); + size_t matchLength = MIN(commonLengthSmaller, commonLengthLarger); /* guaranteed minimum nb of common bytes */ + const BYTE *match; + + if ((!extDict) || (matchIndex + matchLength >= dictLimit)) { + match = base + matchIndex; + if (match[matchLength] == ip[matchLength]) { + matchLength += ZSTD_count(ip + matchLength + 1, match + matchLength + 1, iLimit) + 1; + } + } else { + match = dictBase + matchIndex; + matchLength += ZSTD_count_2segments(ip + matchLength, match + matchLength, iLimit, dictEnd, prefixStart); + if (matchIndex + matchLength >= dictLimit) + match = base + matchIndex; /* to prepare for next usage of match[matchLength] */ + } + + if (matchLength > bestLength) { + if (matchLength > matchEndIdx - matchIndex) + matchEndIdx = matchIndex + (U32)matchLength; + bestLength = matchLength; + matches[mnum].off = ZSTD_REP_MOVE_OPT + curr - matchIndex; + matches[mnum].len = (U32)matchLength; + mnum++; + if (matchLength > ZSTD_OPT_NUM) + break; + if (ip + matchLength == iLimit) /* equal : no way to know if inf or sup */ + break; /* drop, to guarantee consistency (miss a little bit of compression) */ + } + + if (match[matchLength] < ip[matchLength]) { + /* match is smaller than curr */ + *smallerPtr = matchIndex; /* update smaller idx */ + commonLengthSmaller = matchLength; /* all smaller will now have at least this guaranteed common length */ + if (matchIndex <= btLow) { + smallerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + smallerPtr = nextPtr + 1; /* new "smaller" => larger of match */ + matchIndex = nextPtr[1]; /* new matchIndex larger than previous (closer to curr) */ + } else { + /* match is larger than curr */ + *largerPtr = matchIndex; + commonLengthLarger = matchLength; + if (matchIndex <= btLow) { + largerPtr = &dummy32; + break; + } /* beyond tree size, stop the search */ + largerPtr = nextPtr; + matchIndex = nextPtr[0]; + } + } + + *smallerPtr = *largerPtr = 0; + +update: + zc->nextToUpdate = (matchEndIdx > curr + 8) ? matchEndIdx - 8 : curr + 1; + return mnum; +} + +/** Tree updater, providing best match */ +static U32 ZSTD_BtGetAllMatches(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iLimit, const U32 maxNbAttempts, const U32 mls, ZSTD_match_t *matches, + const U32 minMatchLen) +{ + if (ip < zc->base + zc->nextToUpdate) + return 0; /* skipped area */ + ZSTD_updateTree(zc, ip, iLimit, maxNbAttempts, mls); + return ZSTD_insertBtAndGetAllMatches(zc, ip, iLimit, maxNbAttempts, mls, 0, matches, minMatchLen); +} + +static U32 ZSTD_BtGetAllMatches_selectMLS(ZSTD_CCtx *zc, /* Index table will be updated */ + const BYTE *ip, const BYTE *const iHighLimit, const U32 maxNbAttempts, const U32 matchLengthSearch, + ZSTD_match_t *matches, const U32 minMatchLen) +{ + switch (matchLengthSearch) { + case 3: return ZSTD_BtGetAllMatches(zc, ip, iHighLimit, maxNbAttempts, 3, matches, minMatchLen); + default: + case 4: return ZSTD_BtGetAllMatches(zc, ip, iHighLimit, maxNbAttempts, 4, matches, minMatchLen); + case 5: return ZSTD_BtGetAllMatches(zc, ip, iHighLimit, maxNbAttempts, 5, matches, minMatchLen); + case 7: + case 6: return ZSTD_BtGetAllMatches(zc, ip, iHighLimit, maxNbAttempts, 6, matches, minMatchLen); + } +} + +/** Tree updater, providing best match */ +static U32 ZSTD_BtGetAllMatches_extDict(ZSTD_CCtx *zc, const BYTE *const ip, const BYTE *const iLimit, const U32 maxNbAttempts, const U32 mls, + ZSTD_match_t *matches, const U32 minMatchLen) +{ + if (ip < zc->base + zc->nextToUpdate) + return 0; /* skipped area */ + ZSTD_updateTree_extDict(zc, ip, iLimit, maxNbAttempts, mls); + return ZSTD_insertBtAndGetAllMatches(zc, ip, iLimit, maxNbAttempts, mls, 1, matches, minMatchLen); +} + +static U32 ZSTD_BtGetAllMatches_selectMLS_extDict(ZSTD_CCtx *zc, /* Index table will be updated */ + const BYTE *ip, const BYTE *const iHighLimit, const U32 maxNbAttempts, const U32 matchLengthSearch, + ZSTD_match_t *matches, const U32 minMatchLen) +{ + switch (matchLengthSearch) { + case 3: return ZSTD_BtGetAllMatches_extDict(zc, ip, iHighLimit, maxNbAttempts, 3, matches, minMatchLen); + default: + case 4: return ZSTD_BtGetAllMatches_extDict(zc, ip, iHighLimit, maxNbAttempts, 4, matches, minMatchLen); + case 5: return ZSTD_BtGetAllMatches_extDict(zc, ip, iHighLimit, maxNbAttempts, 5, matches, minMatchLen); + case 7: + case 6: return ZSTD_BtGetAllMatches_extDict(zc, ip, iHighLimit, maxNbAttempts, 6, matches, minMatchLen); + } +} + +/*-******************************* +* Optimal parser +*********************************/ +FORCE_INLINE +void ZSTD_compressBlock_opt_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const int ultra) +{ + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + const BYTE *const base = ctx->base; + const BYTE *const prefixStart = base + ctx->dictLimit; + + const U32 maxSearches = 1U << ctx->params.cParams.searchLog; + const U32 sufficient_len = ctx->params.cParams.targetLength; + const U32 mls = ctx->params.cParams.searchLength; + const U32 minMatch = (ctx->params.cParams.searchLength == 3) ? 3 : 4; + + ZSTD_optimal_t *opt = seqStorePtr->priceTable; + ZSTD_match_t *matches = seqStorePtr->matchTable; + const BYTE *inr; + U32 offset, rep[ZSTD_REP_NUM]; + + /* init */ + ctx->nextToUpdate3 = ctx->nextToUpdate; + ZSTD_rescaleFreqs(seqStorePtr, (const BYTE *)src, srcSize); + ip += (ip == prefixStart); + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + rep[i] = ctx->rep[i]; + } + + /* Match Loop */ + while (ip < ilimit) { + U32 cur, match_num, last_pos, litlen, price; + U32 u, mlen, best_mlen, best_off, litLength; + memset(opt, 0, sizeof(ZSTD_optimal_t)); + last_pos = 0; + litlen = (U32)(ip - anchor); + + /* check repCode */ + { + U32 i, last_i = ZSTD_REP_CHECK + (ip == anchor); + for (i = (ip == anchor); i < last_i; i++) { + const S32 repCur = (i == ZSTD_REP_MOVE_OPT) ? (rep[0] - 1) : rep[i]; + if ((repCur > 0) && (repCur < (S32)(ip - prefixStart)) && + (ZSTD_readMINMATCH(ip, minMatch) == ZSTD_readMINMATCH(ip - repCur, minMatch))) { + mlen = (U32)ZSTD_count(ip + minMatch, ip + minMatch - repCur, iend) + minMatch; + if (mlen > sufficient_len || mlen >= ZSTD_OPT_NUM) { + best_mlen = mlen; + best_off = i; + cur = 0; + last_pos = 1; + goto _storeSequence; + } + best_off = i - (ip == anchor); + do { + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, best_off, mlen - MINMATCH, ultra); + if (mlen > last_pos || price < opt[mlen].price) + SET_PRICE(mlen, mlen, i, litlen, price); /* note : macro modifies last_pos */ + mlen--; + } while (mlen >= minMatch); + } + } + } + + match_num = ZSTD_BtGetAllMatches_selectMLS(ctx, ip, iend, maxSearches, mls, matches, minMatch); + + if (!last_pos && !match_num) { + ip++; + continue; + } + + if (match_num && (matches[match_num - 1].len > sufficient_len || matches[match_num - 1].len >= ZSTD_OPT_NUM)) { + best_mlen = matches[match_num - 1].len; + best_off = matches[match_num - 1].off; + cur = 0; + last_pos = 1; + goto _storeSequence; + } + + /* set prices using matches at position = 0 */ + best_mlen = (last_pos) ? last_pos : minMatch; + for (u = 0; u < match_num; u++) { + mlen = (u > 0) ? matches[u - 1].len + 1 : best_mlen; + best_mlen = matches[u].len; + while (mlen <= best_mlen) { + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, matches[u].off - 1, mlen - MINMATCH, ultra); + if (mlen > last_pos || price < opt[mlen].price) + SET_PRICE(mlen, mlen, matches[u].off, litlen, price); /* note : macro modifies last_pos */ + mlen++; + } + } + + if (last_pos < minMatch) { + ip++; + continue; + } + + /* initialize opt[0] */ + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + opt[0].rep[i] = rep[i]; + } + opt[0].mlen = 1; + opt[0].litlen = litlen; + + /* check further positions */ + for (cur = 1; cur <= last_pos; cur++) { + inr = ip + cur; + + if (opt[cur - 1].mlen == 1) { + litlen = opt[cur - 1].litlen + 1; + if (cur > litlen) { + price = opt[cur - litlen].price + ZSTD_getLiteralPrice(seqStorePtr, litlen, inr - litlen); + } else + price = ZSTD_getLiteralPrice(seqStorePtr, litlen, anchor); + } else { + litlen = 1; + price = opt[cur - 1].price + ZSTD_getLiteralPrice(seqStorePtr, litlen, inr - 1); + } + + if (cur > last_pos || price <= opt[cur].price) + SET_PRICE(cur, 1, 0, litlen, price); + + if (cur == last_pos) + break; + + if (inr > ilimit) /* last match must start at a minimum distance of 8 from oend */ + continue; + + mlen = opt[cur].mlen; + if (opt[cur].off > ZSTD_REP_MOVE_OPT) { + opt[cur].rep[2] = opt[cur - mlen].rep[1]; + opt[cur].rep[1] = opt[cur - mlen].rep[0]; + opt[cur].rep[0] = opt[cur].off - ZSTD_REP_MOVE_OPT; + } else { + opt[cur].rep[2] = (opt[cur].off > 1) ? opt[cur - mlen].rep[1] : opt[cur - mlen].rep[2]; + opt[cur].rep[1] = (opt[cur].off > 0) ? opt[cur - mlen].rep[0] : opt[cur - mlen].rep[1]; + opt[cur].rep[0] = + ((opt[cur].off == ZSTD_REP_MOVE_OPT) && (mlen != 1)) ? (opt[cur - mlen].rep[0] - 1) : (opt[cur - mlen].rep[opt[cur].off]); + } + + best_mlen = minMatch; + { + U32 i, last_i = ZSTD_REP_CHECK + (mlen != 1); + for (i = (opt[cur].mlen != 1); i < last_i; i++) { /* check rep */ + const S32 repCur = (i == ZSTD_REP_MOVE_OPT) ? (opt[cur].rep[0] - 1) : opt[cur].rep[i]; + if ((repCur > 0) && (repCur < (S32)(inr - prefixStart)) && + (ZSTD_readMINMATCH(inr, minMatch) == ZSTD_readMINMATCH(inr - repCur, minMatch))) { + mlen = (U32)ZSTD_count(inr + minMatch, inr + minMatch - repCur, iend) + minMatch; + + if (mlen > sufficient_len || cur + mlen >= ZSTD_OPT_NUM) { + best_mlen = mlen; + best_off = i; + last_pos = cur + 1; + goto _storeSequence; + } + + best_off = i - (opt[cur].mlen != 1); + if (mlen > best_mlen) + best_mlen = mlen; + + do { + if (opt[cur].mlen == 1) { + litlen = opt[cur].litlen; + if (cur > litlen) { + price = opt[cur - litlen].price + ZSTD_getPrice(seqStorePtr, litlen, inr - litlen, + best_off, mlen - MINMATCH, ultra); + } else + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, best_off, mlen - MINMATCH, ultra); + } else { + litlen = 0; + price = opt[cur].price + ZSTD_getPrice(seqStorePtr, 0, NULL, best_off, mlen - MINMATCH, ultra); + } + + if (cur + mlen > last_pos || price <= opt[cur + mlen].price) + SET_PRICE(cur + mlen, mlen, i, litlen, price); + mlen--; + } while (mlen >= minMatch); + } + } + } + + match_num = ZSTD_BtGetAllMatches_selectMLS(ctx, inr, iend, maxSearches, mls, matches, best_mlen); + + if (match_num > 0 && (matches[match_num - 1].len > sufficient_len || cur + matches[match_num - 1].len >= ZSTD_OPT_NUM)) { + best_mlen = matches[match_num - 1].len; + best_off = matches[match_num - 1].off; + last_pos = cur + 1; + goto _storeSequence; + } + + /* set prices using matches at position = cur */ + for (u = 0; u < match_num; u++) { + mlen = (u > 0) ? matches[u - 1].len + 1 : best_mlen; + best_mlen = matches[u].len; + + while (mlen <= best_mlen) { + if (opt[cur].mlen == 1) { + litlen = opt[cur].litlen; + if (cur > litlen) + price = opt[cur - litlen].price + ZSTD_getPrice(seqStorePtr, litlen, ip + cur - litlen, + matches[u].off - 1, mlen - MINMATCH, ultra); + else + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, matches[u].off - 1, mlen - MINMATCH, ultra); + } else { + litlen = 0; + price = opt[cur].price + ZSTD_getPrice(seqStorePtr, 0, NULL, matches[u].off - 1, mlen - MINMATCH, ultra); + } + + if (cur + mlen > last_pos || (price < opt[cur + mlen].price)) + SET_PRICE(cur + mlen, mlen, matches[u].off, litlen, price); + + mlen++; + } + } + } + + best_mlen = opt[last_pos].mlen; + best_off = opt[last_pos].off; + cur = last_pos - best_mlen; + + /* store sequence */ +_storeSequence: /* cur, last_pos, best_mlen, best_off have to be set */ + opt[0].mlen = 1; + + while (1) { + mlen = opt[cur].mlen; + offset = opt[cur].off; + opt[cur].mlen = best_mlen; + opt[cur].off = best_off; + best_mlen = mlen; + best_off = offset; + if (mlen > cur) + break; + cur -= mlen; + } + + for (u = 0; u <= last_pos;) { + u += opt[u].mlen; + } + + for (cur = 0; cur < last_pos;) { + mlen = opt[cur].mlen; + if (mlen == 1) { + ip++; + cur++; + continue; + } + offset = opt[cur].off; + cur += mlen; + litLength = (U32)(ip - anchor); + + if (offset > ZSTD_REP_MOVE_OPT) { + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = offset - ZSTD_REP_MOVE_OPT; + offset--; + } else { + if (offset != 0) { + best_off = (offset == ZSTD_REP_MOVE_OPT) ? (rep[0] - 1) : (rep[offset]); + if (offset != 1) + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = best_off; + } + if (litLength == 0) + offset--; + } + + ZSTD_updatePrice(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH); + ZSTD_storeSeq(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH); + anchor = ip = ip + mlen; + } + } /* for (cur=0; cur < last_pos; ) */ + + /* Save reps for next block */ + { + int i; + for (i = 0; i < ZSTD_REP_NUM; i++) + ctx->repToConfirm[i] = rep[i]; + } + + /* Last Literals */ + { + size_t const lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +FORCE_INLINE +void ZSTD_compressBlock_opt_extDict_generic(ZSTD_CCtx *ctx, const void *src, size_t srcSize, const int ultra) +{ + seqStore_t *seqStorePtr = &(ctx->seqStore); + const BYTE *const istart = (const BYTE *)src; + const BYTE *ip = istart; + const BYTE *anchor = istart; + const BYTE *const iend = istart + srcSize; + const BYTE *const ilimit = iend - 8; + const BYTE *const base = ctx->base; + const U32 lowestIndex = ctx->lowLimit; + const U32 dictLimit = ctx->dictLimit; + const BYTE *const prefixStart = base + dictLimit; + const BYTE *const dictBase = ctx->dictBase; + const BYTE *const dictEnd = dictBase + dictLimit; + + const U32 maxSearches = 1U << ctx->params.cParams.searchLog; + const U32 sufficient_len = ctx->params.cParams.targetLength; + const U32 mls = ctx->params.cParams.searchLength; + const U32 minMatch = (ctx->params.cParams.searchLength == 3) ? 3 : 4; + + ZSTD_optimal_t *opt = seqStorePtr->priceTable; + ZSTD_match_t *matches = seqStorePtr->matchTable; + const BYTE *inr; + + /* init */ + U32 offset, rep[ZSTD_REP_NUM]; + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + rep[i] = ctx->rep[i]; + } + + ctx->nextToUpdate3 = ctx->nextToUpdate; + ZSTD_rescaleFreqs(seqStorePtr, (const BYTE *)src, srcSize); + ip += (ip == prefixStart); + + /* Match Loop */ + while (ip < ilimit) { + U32 cur, match_num, last_pos, litlen, price; + U32 u, mlen, best_mlen, best_off, litLength; + U32 curr = (U32)(ip - base); + memset(opt, 0, sizeof(ZSTD_optimal_t)); + last_pos = 0; + opt[0].litlen = (U32)(ip - anchor); + + /* check repCode */ + { + U32 i, last_i = ZSTD_REP_CHECK + (ip == anchor); + for (i = (ip == anchor); i < last_i; i++) { + const S32 repCur = (i == ZSTD_REP_MOVE_OPT) ? (rep[0] - 1) : rep[i]; + const U32 repIndex = (U32)(curr - repCur); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if ((repCur > 0 && repCur <= (S32)curr) && + (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + && (ZSTD_readMINMATCH(ip, minMatch) == ZSTD_readMINMATCH(repMatch, minMatch))) { + /* repcode detected we should take it */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + mlen = (U32)ZSTD_count_2segments(ip + minMatch, repMatch + minMatch, iend, repEnd, prefixStart) + minMatch; + + if (mlen > sufficient_len || mlen >= ZSTD_OPT_NUM) { + best_mlen = mlen; + best_off = i; + cur = 0; + last_pos = 1; + goto _storeSequence; + } + + best_off = i - (ip == anchor); + litlen = opt[0].litlen; + do { + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, best_off, mlen - MINMATCH, ultra); + if (mlen > last_pos || price < opt[mlen].price) + SET_PRICE(mlen, mlen, i, litlen, price); /* note : macro modifies last_pos */ + mlen--; + } while (mlen >= minMatch); + } + } + } + + match_num = ZSTD_BtGetAllMatches_selectMLS_extDict(ctx, ip, iend, maxSearches, mls, matches, minMatch); /* first search (depth 0) */ + + if (!last_pos && !match_num) { + ip++; + continue; + } + + { + U32 i; + for (i = 0; i < ZSTD_REP_NUM; i++) + opt[0].rep[i] = rep[i]; + } + opt[0].mlen = 1; + + if (match_num && (matches[match_num - 1].len > sufficient_len || matches[match_num - 1].len >= ZSTD_OPT_NUM)) { + best_mlen = matches[match_num - 1].len; + best_off = matches[match_num - 1].off; + cur = 0; + last_pos = 1; + goto _storeSequence; + } + + best_mlen = (last_pos) ? last_pos : minMatch; + + /* set prices using matches at position = 0 */ + for (u = 0; u < match_num; u++) { + mlen = (u > 0) ? matches[u - 1].len + 1 : best_mlen; + best_mlen = matches[u].len; + litlen = opt[0].litlen; + while (mlen <= best_mlen) { + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, matches[u].off - 1, mlen - MINMATCH, ultra); + if (mlen > last_pos || price < opt[mlen].price) + SET_PRICE(mlen, mlen, matches[u].off, litlen, price); + mlen++; + } + } + + if (last_pos < minMatch) { + ip++; + continue; + } + + /* check further positions */ + for (cur = 1; cur <= last_pos; cur++) { + inr = ip + cur; + + if (opt[cur - 1].mlen == 1) { + litlen = opt[cur - 1].litlen + 1; + if (cur > litlen) { + price = opt[cur - litlen].price + ZSTD_getLiteralPrice(seqStorePtr, litlen, inr - litlen); + } else + price = ZSTD_getLiteralPrice(seqStorePtr, litlen, anchor); + } else { + litlen = 1; + price = opt[cur - 1].price + ZSTD_getLiteralPrice(seqStorePtr, litlen, inr - 1); + } + + if (cur > last_pos || price <= opt[cur].price) + SET_PRICE(cur, 1, 0, litlen, price); + + if (cur == last_pos) + break; + + if (inr > ilimit) /* last match must start at a minimum distance of 8 from oend */ + continue; + + mlen = opt[cur].mlen; + if (opt[cur].off > ZSTD_REP_MOVE_OPT) { + opt[cur].rep[2] = opt[cur - mlen].rep[1]; + opt[cur].rep[1] = opt[cur - mlen].rep[0]; + opt[cur].rep[0] = opt[cur].off - ZSTD_REP_MOVE_OPT; + } else { + opt[cur].rep[2] = (opt[cur].off > 1) ? opt[cur - mlen].rep[1] : opt[cur - mlen].rep[2]; + opt[cur].rep[1] = (opt[cur].off > 0) ? opt[cur - mlen].rep[0] : opt[cur - mlen].rep[1]; + opt[cur].rep[0] = + ((opt[cur].off == ZSTD_REP_MOVE_OPT) && (mlen != 1)) ? (opt[cur - mlen].rep[0] - 1) : (opt[cur - mlen].rep[opt[cur].off]); + } + + best_mlen = minMatch; + { + U32 i, last_i = ZSTD_REP_CHECK + (mlen != 1); + for (i = (mlen != 1); i < last_i; i++) { + const S32 repCur = (i == ZSTD_REP_MOVE_OPT) ? (opt[cur].rep[0] - 1) : opt[cur].rep[i]; + const U32 repIndex = (U32)(curr + cur - repCur); + const BYTE *const repBase = repIndex < dictLimit ? dictBase : base; + const BYTE *const repMatch = repBase + repIndex; + if ((repCur > 0 && repCur <= (S32)(curr + cur)) && + (((U32)((dictLimit - 1) - repIndex) >= 3) & (repIndex > lowestIndex)) /* intentional overflow */ + && (ZSTD_readMINMATCH(inr, minMatch) == ZSTD_readMINMATCH(repMatch, minMatch))) { + /* repcode detected */ + const BYTE *const repEnd = repIndex < dictLimit ? dictEnd : iend; + mlen = (U32)ZSTD_count_2segments(inr + minMatch, repMatch + minMatch, iend, repEnd, prefixStart) + minMatch; + + if (mlen > sufficient_len || cur + mlen >= ZSTD_OPT_NUM) { + best_mlen = mlen; + best_off = i; + last_pos = cur + 1; + goto _storeSequence; + } + + best_off = i - (opt[cur].mlen != 1); + if (mlen > best_mlen) + best_mlen = mlen; + + do { + if (opt[cur].mlen == 1) { + litlen = opt[cur].litlen; + if (cur > litlen) { + price = opt[cur - litlen].price + ZSTD_getPrice(seqStorePtr, litlen, inr - litlen, + best_off, mlen - MINMATCH, ultra); + } else + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, best_off, mlen - MINMATCH, ultra); + } else { + litlen = 0; + price = opt[cur].price + ZSTD_getPrice(seqStorePtr, 0, NULL, best_off, mlen - MINMATCH, ultra); + } + + if (cur + mlen > last_pos || price <= opt[cur + mlen].price) + SET_PRICE(cur + mlen, mlen, i, litlen, price); + mlen--; + } while (mlen >= minMatch); + } + } + } + + match_num = ZSTD_BtGetAllMatches_selectMLS_extDict(ctx, inr, iend, maxSearches, mls, matches, minMatch); + + if (match_num > 0 && (matches[match_num - 1].len > sufficient_len || cur + matches[match_num - 1].len >= ZSTD_OPT_NUM)) { + best_mlen = matches[match_num - 1].len; + best_off = matches[match_num - 1].off; + last_pos = cur + 1; + goto _storeSequence; + } + + /* set prices using matches at position = cur */ + for (u = 0; u < match_num; u++) { + mlen = (u > 0) ? matches[u - 1].len + 1 : best_mlen; + best_mlen = matches[u].len; + + while (mlen <= best_mlen) { + if (opt[cur].mlen == 1) { + litlen = opt[cur].litlen; + if (cur > litlen) + price = opt[cur - litlen].price + ZSTD_getPrice(seqStorePtr, litlen, ip + cur - litlen, + matches[u].off - 1, mlen - MINMATCH, ultra); + else + price = ZSTD_getPrice(seqStorePtr, litlen, anchor, matches[u].off - 1, mlen - MINMATCH, ultra); + } else { + litlen = 0; + price = opt[cur].price + ZSTD_getPrice(seqStorePtr, 0, NULL, matches[u].off - 1, mlen - MINMATCH, ultra); + } + + if (cur + mlen > last_pos || (price < opt[cur + mlen].price)) + SET_PRICE(cur + mlen, mlen, matches[u].off, litlen, price); + + mlen++; + } + } + } /* for (cur = 1; cur <= last_pos; cur++) */ + + best_mlen = opt[last_pos].mlen; + best_off = opt[last_pos].off; + cur = last_pos - best_mlen; + + /* store sequence */ +_storeSequence: /* cur, last_pos, best_mlen, best_off have to be set */ + opt[0].mlen = 1; + + while (1) { + mlen = opt[cur].mlen; + offset = opt[cur].off; + opt[cur].mlen = best_mlen; + opt[cur].off = best_off; + best_mlen = mlen; + best_off = offset; + if (mlen > cur) + break; + cur -= mlen; + } + + for (u = 0; u <= last_pos;) { + u += opt[u].mlen; + } + + for (cur = 0; cur < last_pos;) { + mlen = opt[cur].mlen; + if (mlen == 1) { + ip++; + cur++; + continue; + } + offset = opt[cur].off; + cur += mlen; + litLength = (U32)(ip - anchor); + + if (offset > ZSTD_REP_MOVE_OPT) { + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = offset - ZSTD_REP_MOVE_OPT; + offset--; + } else { + if (offset != 0) { + best_off = (offset == ZSTD_REP_MOVE_OPT) ? (rep[0] - 1) : (rep[offset]); + if (offset != 1) + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = best_off; + } + + if (litLength == 0) + offset--; + } + + ZSTD_updatePrice(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH); + ZSTD_storeSeq(seqStorePtr, litLength, anchor, offset, mlen - MINMATCH); + anchor = ip = ip + mlen; + } + } /* for (cur=0; cur < last_pos; ) */ + + /* Save reps for next block */ + { + int i; + for (i = 0; i < ZSTD_REP_NUM; i++) + ctx->repToConfirm[i] = rep[i]; + } + + /* Last Literals */ + { + size_t lastLLSize = iend - anchor; + memcpy(seqStorePtr->lit, anchor, lastLLSize); + seqStorePtr->lit += lastLLSize; + } +} + +#endif /* ZSTD_OPT_H_91842398743 */ -- cgit v1.2.3 From dea632cadd129d1901efde6b1167a6250b2157fb Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 7 Aug 2017 10:31:20 -0700 Subject: lib/mpi: fix build with clang Use just @ to denote comments which works with gcc and clang. Otherwise clang reports an escape sequence error: error: invalid % escape in inline assembly string Use %0-%3 as operand references, this avoids: error: invalid operand in inline asm: 'umull ${1:r}, ${0:r}, ${2:r}, ${3:r}' Also remove superfluous casts on output operands to avoid warnings such as: warning: invalid use of a cast in an inline asm context requiring an l-value Signed-off-by: Stefan Agner Acked-by: Arnd Bergmann Signed-off-by: Herbert Xu --- lib/mpi/longlong.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'lib') diff --git a/lib/mpi/longlong.h b/lib/mpi/longlong.h index 93336502af08..57fd45ab7af1 100644 --- a/lib/mpi/longlong.h +++ b/lib/mpi/longlong.h @@ -176,8 +176,8 @@ extern UDItype __udiv_qrnnd(UDItype *, UDItype, UDItype, UDItype); #define add_ssaaaa(sh, sl, ah, al, bh, bl) \ __asm__ ("adds %1, %4, %5\n" \ "adc %0, %2, %3" \ - : "=r" ((USItype)(sh)), \ - "=&r" ((USItype)(sl)) \ + : "=r" (sh), \ + "=&r" (sl) \ : "%r" ((USItype)(ah)), \ "rI" ((USItype)(bh)), \ "%r" ((USItype)(al)), \ @@ -185,15 +185,15 @@ extern UDItype __udiv_qrnnd(UDItype *, UDItype, UDItype, UDItype); #define sub_ddmmss(sh, sl, ah, al, bh, bl) \ __asm__ ("subs %1, %4, %5\n" \ "sbc %0, %2, %3" \ - : "=r" ((USItype)(sh)), \ - "=&r" ((USItype)(sl)) \ + : "=r" (sh), \ + "=&r" (sl) \ : "r" ((USItype)(ah)), \ "rI" ((USItype)(bh)), \ "r" ((USItype)(al)), \ "rI" ((USItype)(bl))) #if defined __ARM_ARCH_2__ || defined __ARM_ARCH_3__ #define umul_ppmm(xh, xl, a, b) \ - __asm__ ("%@ Inlined umul_ppmm\n" \ + __asm__ ("@ Inlined umul_ppmm\n" \ "mov %|r0, %2, lsr #16 @ AAAA\n" \ "mov %|r2, %3, lsr #16 @ BBBB\n" \ "bic %|r1, %2, %|r0, lsl #16 @ aaaa\n" \ @@ -206,19 +206,19 @@ extern UDItype __udiv_qrnnd(UDItype *, UDItype, UDItype, UDItype); "addcs %|r2, %|r2, #65536\n" \ "adds %1, %|r1, %|r0, lsl #16\n" \ "adc %0, %|r2, %|r0, lsr #16" \ - : "=&r" ((USItype)(xh)), \ - "=r" ((USItype)(xl)) \ + : "=&r" (xh), \ + "=r" (xl) \ : "r" ((USItype)(a)), \ "r" ((USItype)(b)) \ : "r0", "r1", "r2") #else #define umul_ppmm(xh, xl, a, b) \ - __asm__ ("%@ Inlined umul_ppmm\n" \ - "umull %r1, %r0, %r2, %r3" \ - : "=&r" ((USItype)(xh)), \ - "=&r" ((USItype)(xl)) \ + __asm__ ("@ Inlined umul_ppmm\n" \ + "umull %1, %0, %2, %3" \ + : "=&r" (xh), \ + "=&r" (xl) \ : "r" ((USItype)(a)), \ - "r" ((USItype)(b)) \ + "r" ((USItype)(b)) \ : "r0", "r1") #endif #define UMUL_TIME 20 -- cgit v1.2.3 From d0541b0fa64b36665d6261079974a26943c75009 Mon Sep 17 00:00:00 2001 From: Byungchul Park Date: Thu, 17 Aug 2017 17:57:39 +0900 Subject: locking/lockdep: Make CONFIG_LOCKDEP_CROSSRELEASE part of CONFIG_PROVE_LOCKING Crossrelease support added the CONFIG_LOCKDEP_CROSSRELEASE and CONFIG_LOCKDEP_COMPLETE options. It makes little sense to enable them when PROVE_LOCKING is disabled. Make them non-interative options and part of PROVE_LOCKING to simplify the UI. Signed-off-by: Byungchul Park Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: kernel-team@lge.com Link: http://lkml.kernel.org/r/1502960261-16206-1-git-send-email-byungchul.park@lge.com Signed-off-by: Ingo Molnar --- lib/Kconfig.debug | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'lib') diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index ebd40d3b0b28..1ad7f1b9160e 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1081,6 +1081,8 @@ config PROVE_LOCKING select DEBUG_MUTEXES select DEBUG_RT_MUTEXES if RT_MUTEXES select DEBUG_LOCK_ALLOC + select LOCKDEP_CROSSRELEASE + select LOCKDEP_COMPLETE select TRACE_IRQFLAGS default n help @@ -1152,8 +1154,6 @@ config LOCK_STAT config LOCKDEP_CROSSRELEASE bool "Lock debugging: make lockdep work for crosslocks" - depends on PROVE_LOCKING - default n help This makes lockdep work for crosslock which is a lock allowed to be released in a different context from the acquisition context. @@ -1164,9 +1164,6 @@ config LOCKDEP_CROSSRELEASE config LOCKDEP_COMPLETE bool "Lock debugging: allow completions to use deadlock detector" - depends on PROVE_LOCKING - select LOCKDEP_CROSSRELEASE - default n help A deadlock caused by wait_for_completion() and complete() can be detected by lockdep using crossrelease feature. -- cgit v1.2.3 From 0f0a22260d613b4ee3f483ee1ea6fa27f92a9e40 Mon Sep 17 00:00:00 2001 From: Byungchul Park Date: Thu, 17 Aug 2017 17:57:40 +0900 Subject: locking/lockdep: Reword title of LOCKDEP_CROSSRELEASE config Lockdep doesn't have to be made to work with crossrelease and just works with them. Reword the title so that what the option does is clear. Signed-off-by: Byungchul Park Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: kernel-team@lge.com Link: http://lkml.kernel.org/r/1502960261-16206-2-git-send-email-byungchul.park@lge.com Signed-off-by: Ingo Molnar --- lib/Kconfig.debug | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 1ad7f1b9160e..8fb8a206db12 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1153,7 +1153,7 @@ config LOCK_STAT (CONFIG_LOCKDEP defines "acquire" and "release" events.) config LOCKDEP_CROSSRELEASE - bool "Lock debugging: make lockdep work for crosslocks" + bool "Lock debugging: enable cross-locking checks in lockdep" help This makes lockdep work for crosslock which is a lock allowed to be released in a different context from the acquisition context. -- cgit v1.2.3 From ea3f2c0fdfbb180f142cdbc0d1f055c7b9a5e63e Mon Sep 17 00:00:00 2001 From: Byungchul Park Date: Thu, 17 Aug 2017 17:57:41 +0900 Subject: locking/lockdep: Rename CONFIG_LOCKDEP_COMPLETE to CONFIG_LOCKDEP_COMPLETIONS 'complete' is an adjective and LOCKDEP_COMPLETE sounds like 'lockdep is complete', so pick a better name that uses a noun. Signed-off-by: Byungchul Park Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: kernel-team@lge.com Link: http://lkml.kernel.org/r/1502960261-16206-3-git-send-email-byungchul.park@lge.com Signed-off-by: Ingo Molnar --- include/linux/completion.h | 8 ++++---- lib/Kconfig.debug | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'lib') diff --git a/include/linux/completion.h b/include/linux/completion.h index 9bcebf509b4d..791f053f28b7 100644 --- a/include/linux/completion.h +++ b/include/linux/completion.h @@ -9,7 +9,7 @@ */ #include -#ifdef CONFIG_LOCKDEP_COMPLETE +#ifdef CONFIG_LOCKDEP_COMPLETIONS #include #endif @@ -28,12 +28,12 @@ struct completion { unsigned int done; wait_queue_head_t wait; -#ifdef CONFIG_LOCKDEP_COMPLETE +#ifdef CONFIG_LOCKDEP_COMPLETIONS struct lockdep_map_cross map; #endif }; -#ifdef CONFIG_LOCKDEP_COMPLETE +#ifdef CONFIG_LOCKDEP_COMPLETIONS static inline void complete_acquire(struct completion *x) { lock_acquire_exclusive((struct lockdep_map *)&x->map, 0, 0, NULL, _RET_IP_); @@ -64,7 +64,7 @@ static inline void complete_release(struct completion *x) {} static inline void complete_release_commit(struct completion *x) {} #endif -#ifdef CONFIG_LOCKDEP_COMPLETE +#ifdef CONFIG_LOCKDEP_COMPLETIONS #define COMPLETION_INITIALIZER(work) \ { 0, __WAIT_QUEUE_HEAD_INITIALIZER((work).wait), \ STATIC_CROSS_LOCKDEP_MAP_INIT("(complete)" #work, &(work)) } diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 8fb8a206db12..a0e60d56303e 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1082,7 +1082,7 @@ config PROVE_LOCKING select DEBUG_RT_MUTEXES if RT_MUTEXES select DEBUG_LOCK_ALLOC select LOCKDEP_CROSSRELEASE - select LOCKDEP_COMPLETE + select LOCKDEP_COMPLETIONS select TRACE_IRQFLAGS default n help @@ -1162,7 +1162,7 @@ config LOCKDEP_CROSSRELEASE such as page locks or completions can use the lock correctness detector, lockdep. -config LOCKDEP_COMPLETE +config LOCKDEP_COMPLETIONS bool "Lock debugging: allow completions to use deadlock detector" help A deadlock caused by wait_for_completion() and complete() can be -- cgit v1.2.3 From e26f34a407aec9c65bce2bc0c838fabe4f051fc6 Mon Sep 17 00:00:00 2001 From: Ingo Molnar Date: Thu, 17 Aug 2017 12:48:29 +0200 Subject: locking/lockdep: Make CONFIG_LOCKDEP_CROSSRELEASE and CONFIG_LOCKDEP_COMPLETIONS truly non-interactive The syntax to turn Kconfig options into non-interactive ones is to not offer interactive prompt help texts. Remove them. Cc: Boqun Feng Cc: Byungchul Park Cc: Lai Jiangshan Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Tejun Heo Cc: Thomas Gleixner Cc: linux-kernel@vger.kernel.org Signed-off-by: Ingo Molnar --- lib/Kconfig.debug | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index a0e60d56303e..eb0f316b731a 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1153,7 +1153,7 @@ config LOCK_STAT (CONFIG_LOCKDEP defines "acquire" and "release" events.) config LOCKDEP_CROSSRELEASE - bool "Lock debugging: enable cross-locking checks in lockdep" + bool help This makes lockdep work for crosslock which is a lock allowed to be released in a different context from the acquisition context. @@ -1163,7 +1163,7 @@ config LOCKDEP_CROSSRELEASE detector, lockdep. config LOCKDEP_COMPLETIONS - bool "Lock debugging: allow completions to use deadlock detector" + bool help A deadlock caused by wait_for_completion() and complete() can be detected by lockdep using crossrelease feature. -- cgit v1.2.3 From 7edaeb6841dfb27e362288ab8466ebdc4972e867 Mon Sep 17 00:00:00 2001 From: Thomas Gleixner Date: Tue, 15 Aug 2017 09:50:13 +0200 Subject: kernel/watchdog: Prevent false positives with turbo modes The hardlockup detector on x86 uses a performance counter based on unhalted CPU cycles and a periodic hrtimer. The hrtimer period is about 2/5 of the performance counter period, so the hrtimer should fire 2-3 times before the performance counter NMI fires. The NMI code checks whether the hrtimer fired since the last invocation. If not, it assumess a hard lockup. The calculation of those periods is based on the nominal CPU frequency. Turbo modes increase the CPU clock frequency and therefore shorten the period of the perf/NMI watchdog. With extreme Turbo-modes (3x nominal frequency) the perf/NMI period is shorter than the hrtimer period which leads to false positives. A simple fix would be to shorten the hrtimer period, but that comes with the side effect of more frequent hrtimer and softlockup thread wakeups, which is not desired. Implement a low pass filter, which checks the perf/NMI period against kernel time. If the perf/NMI fires before 4/5 of the watchdog period has elapsed then the event is ignored and postponed to the next perf/NMI. That solves the problem and avoids the overhead of shorter hrtimer periods and more frequent softlockup thread wakeups. Fixes: 58687acba592 ("lockup_detector: Combine nmi_watchdog and softlockup detector") Reported-and-tested-by: Kan Liang Signed-off-by: Thomas Gleixner Cc: dzickus@redhat.com Cc: prarit@redhat.com Cc: ak@linux.intel.com Cc: babu.moger@oracle.com Cc: peterz@infradead.org Cc: eranian@google.com Cc: acme@redhat.com Cc: stable@vger.kernel.org Cc: atomlin@redhat.com Cc: akpm@linux-foundation.org Cc: torvalds@linux-foundation.org Link: http://lkml.kernel.org/r/alpine.DEB.2.20.1708150931310.1886@nanos --- arch/x86/Kconfig | 1 + include/linux/nmi.h | 8 +++++++ kernel/watchdog.c | 1 + kernel/watchdog_hld.c | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++ lib/Kconfig.debug | 7 ++++++ 5 files changed, 76 insertions(+) (limited to 'lib') diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 781521b7cf9e..9101bfc85539 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -100,6 +100,7 @@ config X86 select GENERIC_STRNCPY_FROM_USER select GENERIC_STRNLEN_USER select GENERIC_TIME_VSYSCALL + select HARDLOCKUP_CHECK_TIMESTAMP if X86_64 select HAVE_ACPI_APEI if ACPI select HAVE_ACPI_APEI_NMI if ACPI select HAVE_ALIGNED_STRUCT_PAGE if SLUB diff --git a/include/linux/nmi.h b/include/linux/nmi.h index 8aa01fd859fb..a36abe2da13e 100644 --- a/include/linux/nmi.h +++ b/include/linux/nmi.h @@ -168,6 +168,14 @@ extern int sysctl_hardlockup_all_cpu_backtrace; #define sysctl_softlockup_all_cpu_backtrace 0 #define sysctl_hardlockup_all_cpu_backtrace 0 #endif + +#if defined(CONFIG_HARDLOCKUP_CHECK_TIMESTAMP) && \ + defined(CONFIG_HARDLOCKUP_DETECTOR) +void watchdog_update_hrtimer_threshold(u64 period); +#else +static inline void watchdog_update_hrtimer_threshold(u64 period) { } +#endif + extern bool is_hardlockup(void); struct ctl_table; extern int proc_watchdog(struct ctl_table *, int , diff --git a/kernel/watchdog.c b/kernel/watchdog.c index 06d3389bca0d..f5d52024f6b7 100644 --- a/kernel/watchdog.c +++ b/kernel/watchdog.c @@ -240,6 +240,7 @@ static void set_sample_period(void) * hardlockup detector generates a warning */ sample_period = get_softlockup_thresh() * ((u64)NSEC_PER_SEC / 5); + watchdog_update_hrtimer_threshold(sample_period); } /* Commands for resetting the watchdog */ diff --git a/kernel/watchdog_hld.c b/kernel/watchdog_hld.c index 295a0d84934c..3a09ea1b1d3d 100644 --- a/kernel/watchdog_hld.c +++ b/kernel/watchdog_hld.c @@ -37,6 +37,62 @@ void arch_touch_nmi_watchdog(void) } EXPORT_SYMBOL(arch_touch_nmi_watchdog); +#ifdef CONFIG_HARDLOCKUP_CHECK_TIMESTAMP +static DEFINE_PER_CPU(ktime_t, last_timestamp); +static DEFINE_PER_CPU(unsigned int, nmi_rearmed); +static ktime_t watchdog_hrtimer_sample_threshold __read_mostly; + +void watchdog_update_hrtimer_threshold(u64 period) +{ + /* + * The hrtimer runs with a period of (watchdog_threshold * 2) / 5 + * + * So it runs effectively with 2.5 times the rate of the NMI + * watchdog. That means the hrtimer should fire 2-3 times before + * the NMI watchdog expires. The NMI watchdog on x86 is based on + * unhalted CPU cycles, so if Turbo-Mode is enabled the CPU cycles + * might run way faster than expected and the NMI fires in a + * smaller period than the one deduced from the nominal CPU + * frequency. Depending on the Turbo-Mode factor this might be fast + * enough to get the NMI period smaller than the hrtimer watchdog + * period and trigger false positives. + * + * The sample threshold is used to check in the NMI handler whether + * the minimum time between two NMI samples has elapsed. That + * prevents false positives. + * + * Set this to 4/5 of the actual watchdog threshold period so the + * hrtimer is guaranteed to fire at least once within the real + * watchdog threshold. + */ + watchdog_hrtimer_sample_threshold = period * 2; +} + +static bool watchdog_check_timestamp(void) +{ + ktime_t delta, now = ktime_get_mono_fast_ns(); + + delta = now - __this_cpu_read(last_timestamp); + if (delta < watchdog_hrtimer_sample_threshold) { + /* + * If ktime is jiffies based, a stalled timer would prevent + * jiffies from being incremented and the filter would look + * at a stale timestamp and never trigger. + */ + if (__this_cpu_inc_return(nmi_rearmed) < 10) + return false; + } + __this_cpu_write(nmi_rearmed, 0); + __this_cpu_write(last_timestamp, now); + return true; +} +#else +static inline bool watchdog_check_timestamp(void) +{ + return true; +} +#endif + static struct perf_event_attr wd_hw_attr = { .type = PERF_TYPE_HARDWARE, .config = PERF_COUNT_HW_CPU_CYCLES, @@ -61,6 +117,9 @@ static void watchdog_overflow_callback(struct perf_event *event, return; } + if (!watchdog_check_timestamp()) + return; + /* check for a hardlockup * This is done by making sure our timer interrupt * is incrementing. The timer interrupt should have diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 98fe715522e8..c617b9d1d6cb 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -797,6 +797,13 @@ config HARDLOCKUP_DETECTOR_PERF bool select SOFTLOCKUP_DETECTOR +# +# Enables a timestamp based low pass filter to compensate for perf based +# hard lockup detection which runs too fast due to turbo modes. +# +config HARDLOCKUP_CHECK_TIMESTAMP + bool + # # arch/ can define HAVE_HARDLOCKUP_DETECTOR_ARCH to provide their own hard # lockup detector rather than the perf based detector. -- cgit v1.2.3 From dea3eb8b452e36cf2dd572b0a797915ccf452ae6 Mon Sep 17 00:00:00 2001 From: Stephan Mueller Date: Thu, 10 Aug 2017 08:06:18 +0200 Subject: lib/mpi: kunmap after finishing accessing buffer Using sg_miter_start and sg_miter_next, the buffer of an SG is kmap'ed to *buff. The current code calls sg_miter_stop (and thus kunmap) on the SG entry before the last access of *buff. The patch moves the sg_miter_stop call after the last access to *buff to ensure that the memory pointed to by *buff is still mapped. Fixes: 4816c9406430 ("lib/mpi: Fix SG miter leak") Cc: Signed-off-by: Stephan Mueller Signed-off-by: Herbert Xu --- lib/mpi/mpicoder.c | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/mpi/mpicoder.c b/lib/mpi/mpicoder.c index 5a0f75a3bf01..eead4b339466 100644 --- a/lib/mpi/mpicoder.c +++ b/lib/mpi/mpicoder.c @@ -364,11 +364,11 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) } miter.consumed = lzeros; - sg_miter_stop(&miter); nbytes -= lzeros; nbits = nbytes * 8; if (nbits > MAX_EXTERN_MPI_BITS) { + sg_miter_stop(&miter); pr_info("MPI: mpi too large (%u bits)\n", nbits); return NULL; } @@ -376,6 +376,8 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) if (nbytes > 0) nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8); + sg_miter_stop(&miter); + nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); val = mpi_alloc(nlimbs); if (!val) -- cgit v1.2.3 From e91498589746065e3ae95d9a00b068e525eec34f Mon Sep 17 00:00:00 2001 From: Peter Zijlstra Date: Wed, 23 Aug 2017 13:13:11 +0200 Subject: locking/lockdep/selftests: Add mixed read-write ABBA tests Currently lockdep has limited support for recursive readers, add a few mixed read-write ABBA selftests to show the extend of these limitations. [ 0.000000] ---------------------------------------------------------------------------- [ 0.000000] | spin |wlock |rlock |mutex | wsem | rsem | [ 0.000000] -------------------------------------------------------------------------- [ 0.000000] mixed read-lock/lock-write ABBA: |FAILED| | ok | [ 0.000000] mixed read-lock/lock-read ABBA: | ok | | ok | [ 0.000000] mixed write-lock/lock-write ABBA: | ok | | ok | This clearly illustrates the case where lockdep fails to find a deadlock. Signed-off-by: Peter Zijlstra (Intel) Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: boqun.feng@gmail.com Cc: byungchul.park@lge.com Cc: david@fromorbit.com Cc: johannes@sipsolutions.net Cc: oleg@redhat.com Cc: tj@kernel.org Signed-off-by: Ingo Molnar --- lib/locking-selftest.c | 117 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/locking-selftest.c b/lib/locking-selftest.c index 6f2b135dc5e8..3c7151a6cd98 100644 --- a/lib/locking-selftest.c +++ b/lib/locking-selftest.c @@ -362,6 +362,103 @@ static void rsem_AA3(void) RSL(X2); // this one should fail } +/* + * read_lock(A) + * spin_lock(B) + * spin_lock(B) + * write_lock(A) + */ +static void rlock_ABBA1(void) +{ + RL(X1); + L(Y1); + U(Y1); + RU(X1); + + L(Y1); + WL(X1); + WU(X1); + U(Y1); // should fail +} + +static void rwsem_ABBA1(void) +{ + RSL(X1); + ML(Y1); + MU(Y1); + RSU(X1); + + ML(Y1); + WSL(X1); + WSU(X1); + MU(Y1); // should fail +} + +/* + * read_lock(A) + * spin_lock(B) + * spin_lock(B) + * read_lock(A) + */ +static void rlock_ABBA2(void) +{ + RL(X1); + L(Y1); + U(Y1); + RU(X1); + + L(Y1); + RL(X1); + RU(X1); + U(Y1); // should NOT fail +} + +static void rwsem_ABBA2(void) +{ + RSL(X1); + ML(Y1); + MU(Y1); + RSU(X1); + + ML(Y1); + RSL(X1); + RSU(X1); + MU(Y1); // should fail +} + + +/* + * write_lock(A) + * spin_lock(B) + * spin_lock(B) + * write_lock(A) + */ +static void rlock_ABBA3(void) +{ + WL(X1); + L(Y1); + U(Y1); + WU(X1); + + L(Y1); + WL(X1); + WU(X1); + U(Y1); // should fail +} + +static void rwsem_ABBA3(void) +{ + WSL(X1); + ML(Y1); + MU(Y1); + WSU(X1); + + ML(Y1); + WSL(X1); + WSU(X1); + MU(Y1); // should fail +} + /* * ABBA deadlock: */ @@ -1056,8 +1153,6 @@ static void dotest(void (*testcase_fn)(void), int expected, int lockclass_mask) if (debug_locks != expected) { unexpected_testcase_failures++; pr_cont("FAILED|"); - - dump_stack(); } else { testcase_successes++; pr_cont(" ok |"); @@ -1933,6 +2028,24 @@ void locking_selftest(void) dotest(rsem_AA3, FAILURE, LOCKTYPE_RWSEM); pr_cont("\n"); + print_testname("mixed read-lock/lock-write ABBA"); + pr_cont(" |"); + dotest(rlock_ABBA1, FAILURE, LOCKTYPE_RWLOCK); + pr_cont(" |"); + dotest(rwsem_ABBA1, FAILURE, LOCKTYPE_RWSEM); + + print_testname("mixed read-lock/lock-read ABBA"); + pr_cont(" |"); + dotest(rlock_ABBA2, SUCCESS, LOCKTYPE_RWLOCK); + pr_cont(" |"); + dotest(rwsem_ABBA2, FAILURE, LOCKTYPE_RWSEM); + + print_testname("mixed write-lock/lock-write ABBA"); + pr_cont(" |"); + dotest(rlock_ABBA3, FAILURE, LOCKTYPE_RWLOCK); + pr_cont(" |"); + dotest(rwsem_ABBA3, FAILURE, LOCKTYPE_RWSEM); + printk(" --------------------------------------------------------------------------\n"); /* -- cgit v1.2.3 From b5e0fff19bef3a36631bce742e363d3a99590c3a Mon Sep 17 00:00:00 2001 From: Denys Vlasenko Date: Sat, 12 Aug 2017 19:43:46 +0200 Subject: lib/raid6: align AVX512 constants to 512 bits, not bytes Signed-off-by: Denys Vlasenko Cc: H. Peter Anvin Cc: mingo@redhat.com Cc: Jim Kukunas Cc: Fenghua Yu Cc: Megha Dey Cc: Gayatri Kammela Cc: x86@kernel.org Cc: linux-kernel@vger.kernel.org Signed-off-by: Shaohua Li --- lib/raid6/avx512.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/raid6/avx512.c b/lib/raid6/avx512.c index f524a7972006..46df7977b971 100644 --- a/lib/raid6/avx512.c +++ b/lib/raid6/avx512.c @@ -29,7 +29,7 @@ static const struct raid6_avx512_constants { u64 x1d[8]; -} raid6_avx512_constants __aligned(512) = { +} raid6_avx512_constants __aligned(512/8) = { { 0x1d1d1d1d1d1d1d1dULL, 0x1d1d1d1d1d1d1d1dULL, 0x1d1d1d1d1d1d1d1dULL, 0x1d1d1d1d1d1d1d1dULL, 0x1d1d1d1d1d1d1d1dULL, 0x1d1d1d1d1d1d1d1dULL, -- cgit v1.2.3 From d82fed75294229abc9d757f08a4817febae6c4f4 Mon Sep 17 00:00:00 2001 From: Peter Zijlstra Date: Mon, 28 Aug 2017 14:42:45 +0200 Subject: locking/lockdep/selftests: Fix mixed read-write ABBA tests Commit: e91498589746 ("locking/lockdep/selftests: Add mixed read-write ABBA tests") adds an explicit FAILURE to the locking selftest but overlooked the fact that this kills lockdep. Fudge the test to avoid this. Signed-off-by: Peter Zijlstra (Intel) Cc: Linus Torvalds Cc: Peter Zijlstra Cc: Thomas Gleixner Cc: hpa@zytor.com Link: http://lkml.kernel.org/r/20170828124245.xlo2yshxq2btgmuf@hirez.programming.kicks-ass.net Signed-off-by: Ingo Molnar --- lib/locking-selftest.c | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'lib') diff --git a/lib/locking-selftest.c b/lib/locking-selftest.c index 3c7151a6cd98..cd0b5c964bd0 100644 --- a/lib/locking-selftest.c +++ b/lib/locking-selftest.c @@ -2031,6 +2031,12 @@ void locking_selftest(void) print_testname("mixed read-lock/lock-write ABBA"); pr_cont(" |"); dotest(rlock_ABBA1, FAILURE, LOCKTYPE_RWLOCK); + /* + * Lockdep does indeed fail here, but there's nothing we can do about + * that now. Don't kill lockdep for it. + */ + unexpected_testcase_failures--; + pr_cont(" |"); dotest(rwsem_ABBA1, FAILURE, LOCKTYPE_RWSEM); -- cgit v1.2.3 From 388f79fda74fd3d8700ed5d899573ec58c2e0253 Mon Sep 17 00:00:00 2001 From: Chris Mi Date: Wed, 30 Aug 2017 02:31:57 -0400 Subject: idr: Add new APIs to support unsigned long The following new APIs are added: int idr_alloc_ext(struct idr *idr, void *ptr, unsigned long *index, unsigned long start, unsigned long end, gfp_t gfp); void *idr_remove_ext(struct idr *idr, unsigned long id); void *idr_find_ext(const struct idr *idr, unsigned long id); void *idr_replace_ext(struct idr *idr, void *ptr, unsigned long id); void *idr_get_next_ext(struct idr *idr, unsigned long *nextid); Signed-off-by: Chris Mi Signed-off-by: Jiri Pirko Signed-off-by: David S. Miller --- include/linux/idr.h | 69 ++++++++++++++++++++++++++++++++++++++++++++-- include/linux/radix-tree.h | 21 ++++++++++++-- lib/idr.c | 66 +++++++++++++++++++++++++------------------- lib/radix-tree.c | 6 ++-- 4 files changed, 125 insertions(+), 37 deletions(-) (limited to 'lib') diff --git a/include/linux/idr.h b/include/linux/idr.h index bf70b3ef0a07..7c3a365f7e12 100644 --- a/include/linux/idr.h +++ b/include/linux/idr.h @@ -80,19 +80,75 @@ static inline void idr_set_cursor(struct idr *idr, unsigned int val) */ void idr_preload(gfp_t gfp_mask); -int idr_alloc(struct idr *, void *entry, int start, int end, gfp_t); + +int idr_alloc_cmn(struct idr *idr, void *ptr, unsigned long *index, + unsigned long start, unsigned long end, gfp_t gfp, + bool ext); + +/** + * idr_alloc - allocate an id + * @idr: idr handle + * @ptr: pointer to be associated with the new id + * @start: the minimum id (inclusive) + * @end: the maximum id (exclusive) + * @gfp: memory allocation flags + * + * Allocates an unused ID in the range [start, end). Returns -ENOSPC + * if there are no unused IDs in that range. + * + * Note that @end is treated as max when <= 0. This is to always allow + * using @start + N as @end as long as N is inside integer range. + * + * Simultaneous modifications to the @idr are not allowed and should be + * prevented by the user, usually with a lock. idr_alloc() may be called + * concurrently with read-only accesses to the @idr, such as idr_find() and + * idr_for_each_entry(). + */ +static inline int idr_alloc(struct idr *idr, void *ptr, + int start, int end, gfp_t gfp) +{ + unsigned long id; + int ret; + + if (WARN_ON_ONCE(start < 0)) + return -EINVAL; + + ret = idr_alloc_cmn(idr, ptr, &id, start, end, gfp, false); + + if (ret) + return ret; + + return id; +} + +static inline int idr_alloc_ext(struct idr *idr, void *ptr, + unsigned long *index, + unsigned long start, + unsigned long end, + gfp_t gfp) +{ + return idr_alloc_cmn(idr, ptr, index, start, end, gfp, true); +} + int idr_alloc_cyclic(struct idr *, void *entry, int start, int end, gfp_t); int idr_for_each(const struct idr *, int (*fn)(int id, void *p, void *data), void *data); void *idr_get_next(struct idr *, int *nextid); +void *idr_get_next_ext(struct idr *idr, unsigned long *nextid); void *idr_replace(struct idr *, void *, int id); +void *idr_replace_ext(struct idr *idr, void *ptr, unsigned long id); void idr_destroy(struct idr *); -static inline void *idr_remove(struct idr *idr, int id) +static inline void *idr_remove_ext(struct idr *idr, unsigned long id) { return radix_tree_delete_item(&idr->idr_rt, id, NULL); } +static inline void *idr_remove(struct idr *idr, int id) +{ + return idr_remove_ext(idr, id); +} + static inline void idr_init(struct idr *idr) { INIT_RADIX_TREE(&idr->idr_rt, IDR_RT_MARKER); @@ -128,11 +184,16 @@ static inline void idr_preload_end(void) * This function can be called under rcu_read_lock(), given that the leaf * pointers lifetimes are correctly managed. */ -static inline void *idr_find(const struct idr *idr, int id) +static inline void *idr_find_ext(const struct idr *idr, unsigned long id) { return radix_tree_lookup(&idr->idr_rt, id); } +static inline void *idr_find(const struct idr *idr, int id) +{ + return idr_find_ext(idr, id); +} + /** * idr_for_each_entry - iterate over an idr's elements of a given type * @idr: idr handle @@ -145,6 +206,8 @@ static inline void *idr_find(const struct idr *idr, int id) */ #define idr_for_each_entry(idr, entry, id) \ for (id = 0; ((entry) = idr_get_next(idr, &(id))) != NULL; ++id) +#define idr_for_each_entry_ext(idr, entry, id) \ + for (id = 0; ((entry) = idr_get_next_ext(idr, &(id))) != NULL; ++id) /** * idr_for_each_entry_continue - continue iteration over an idr's elements of a given type diff --git a/include/linux/radix-tree.h b/include/linux/radix-tree.h index 3e5735064b71..567ebb5eaab0 100644 --- a/include/linux/radix-tree.h +++ b/include/linux/radix-tree.h @@ -357,8 +357,25 @@ int radix_tree_split(struct radix_tree_root *, unsigned long index, unsigned new_order); int radix_tree_join(struct radix_tree_root *, unsigned long index, unsigned new_order, void *); -void __rcu **idr_get_free(struct radix_tree_root *, struct radix_tree_iter *, - gfp_t, int end); + +void __rcu **idr_get_free_cmn(struct radix_tree_root *root, + struct radix_tree_iter *iter, gfp_t gfp, + unsigned long max); +static inline void __rcu **idr_get_free(struct radix_tree_root *root, + struct radix_tree_iter *iter, + gfp_t gfp, + int end) +{ + return idr_get_free_cmn(root, iter, gfp, end > 0 ? end - 1 : INT_MAX); +} + +static inline void __rcu **idr_get_free_ext(struct radix_tree_root *root, + struct radix_tree_iter *iter, + gfp_t gfp, + unsigned long end) +{ + return idr_get_free_cmn(root, iter, gfp, end - 1); +} enum { RADIX_TREE_ITER_TAG_MASK = 0x0f, /* tag index in lower nybble */ diff --git a/lib/idr.c b/lib/idr.c index b13682bb0a1c..082778cf883e 100644 --- a/lib/idr.c +++ b/lib/idr.c @@ -7,45 +7,32 @@ DEFINE_PER_CPU(struct ida_bitmap *, ida_bitmap); static DEFINE_SPINLOCK(simple_ida_lock); -/** - * idr_alloc - allocate an id - * @idr: idr handle - * @ptr: pointer to be associated with the new id - * @start: the minimum id (inclusive) - * @end: the maximum id (exclusive) - * @gfp: memory allocation flags - * - * Allocates an unused ID in the range [start, end). Returns -ENOSPC - * if there are no unused IDs in that range. - * - * Note that @end is treated as max when <= 0. This is to always allow - * using @start + N as @end as long as N is inside integer range. - * - * Simultaneous modifications to the @idr are not allowed and should be - * prevented by the user, usually with a lock. idr_alloc() may be called - * concurrently with read-only accesses to the @idr, such as idr_find() and - * idr_for_each_entry(). - */ -int idr_alloc(struct idr *idr, void *ptr, int start, int end, gfp_t gfp) +int idr_alloc_cmn(struct idr *idr, void *ptr, unsigned long *index, + unsigned long start, unsigned long end, gfp_t gfp, + bool ext) { - void __rcu **slot; struct radix_tree_iter iter; + void __rcu **slot; - if (WARN_ON_ONCE(start < 0)) - return -EINVAL; if (WARN_ON_ONCE(radix_tree_is_internal_node(ptr))) return -EINVAL; radix_tree_iter_init(&iter, start); - slot = idr_get_free(&idr->idr_rt, &iter, gfp, end); + if (ext) + slot = idr_get_free_ext(&idr->idr_rt, &iter, gfp, end); + else + slot = idr_get_free(&idr->idr_rt, &iter, gfp, end); if (IS_ERR(slot)) return PTR_ERR(slot); radix_tree_iter_replace(&idr->idr_rt, &iter, slot, ptr); radix_tree_iter_tag_clear(&idr->idr_rt, &iter, IDR_FREE); - return iter.index; + + if (index) + *index = iter.index; + return 0; } -EXPORT_SYMBOL_GPL(idr_alloc); +EXPORT_SYMBOL_GPL(idr_alloc_cmn); /** * idr_alloc_cyclic - allocate new idr entry in a cyclical fashion @@ -134,6 +121,20 @@ void *idr_get_next(struct idr *idr, int *nextid) } EXPORT_SYMBOL(idr_get_next); +void *idr_get_next_ext(struct idr *idr, unsigned long *nextid) +{ + struct radix_tree_iter iter; + void __rcu **slot; + + slot = radix_tree_iter_find(&idr->idr_rt, &iter, *nextid); + if (!slot) + return NULL; + + *nextid = iter.index; + return rcu_dereference_raw(*slot); +} +EXPORT_SYMBOL(idr_get_next_ext); + /** * idr_replace - replace pointer for given id * @idr: idr handle @@ -149,13 +150,20 @@ EXPORT_SYMBOL(idr_get_next); * %-EINVAL indicates that @id or @ptr were not valid. */ void *idr_replace(struct idr *idr, void *ptr, int id) +{ + if (WARN_ON_ONCE(id < 0)) + return ERR_PTR(-EINVAL); + + return idr_replace_ext(idr, ptr, id); +} +EXPORT_SYMBOL(idr_replace); + +void *idr_replace_ext(struct idr *idr, void *ptr, unsigned long id) { struct radix_tree_node *node; void __rcu **slot = NULL; void *entry; - if (WARN_ON_ONCE(id < 0)) - return ERR_PTR(-EINVAL); if (WARN_ON_ONCE(radix_tree_is_internal_node(ptr))) return ERR_PTR(-EINVAL); @@ -167,7 +175,7 @@ void *idr_replace(struct idr *idr, void *ptr, int id) return entry; } -EXPORT_SYMBOL(idr_replace); +EXPORT_SYMBOL(idr_replace_ext); /** * DOC: IDA description diff --git a/lib/radix-tree.c b/lib/radix-tree.c index 898e87998417..c191b42d1213 100644 --- a/lib/radix-tree.c +++ b/lib/radix-tree.c @@ -2137,13 +2137,13 @@ int ida_pre_get(struct ida *ida, gfp_t gfp) } EXPORT_SYMBOL(ida_pre_get); -void __rcu **idr_get_free(struct radix_tree_root *root, - struct radix_tree_iter *iter, gfp_t gfp, int end) +void __rcu **idr_get_free_cmn(struct radix_tree_root *root, + struct radix_tree_iter *iter, gfp_t gfp, + unsigned long max) { struct radix_tree_node *node = NULL, *child; void __rcu **slot = (void __rcu **)&root->rnode; unsigned long maxindex, start = iter->next_index; - unsigned long max = end > 0 ? end - 1 : INT_MAX; unsigned int shift, offset = 0; grow: -- cgit v1.2.3 From 48c40c26fc0103d40eb402fcd8fcf800302c83ca Mon Sep 17 00:00:00 2001 From: Alexander Kuleshov Date: Wed, 23 Aug 2017 00:39:13 +0600 Subject: assoc_array: fix path to assoc_array documentation Signed-off-by: Alexander Kuleshov Signed-off-by: Jonathan Corbet --- lib/assoc_array.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/assoc_array.c b/lib/assoc_array.c index 59fd7c0b119c..155c55d8db5f 100644 --- a/lib/assoc_array.c +++ b/lib/assoc_array.c @@ -1,6 +1,6 @@ /* Generic associative array implementation. * - * See Documentation/assoc_array.txt for information. + * See Documentation/core-api/assoc_array.rst for information. * * Copyright (C) 2013 Red Hat, Inc. All Rights Reserved. * Written by David Howells (dhowells@redhat.com) -- cgit v1.2.3 From 5deb67f77a266010e2c10fb124b7516d0d258ce8 Mon Sep 17 00:00:00 2001 From: Robin Murphy Date: Thu, 31 Aug 2017 12:27:09 +0100 Subject: libnvdimm, nd_blk: remove mmio_flush_range() mmio_flush_range() suffers from a lack of clearly-defined semantics, and is somewhat ambiguous to port to other architectures where the scope of the writeback implied by "flush" and ordering might matter, but MMIO would tend to imply non-cacheable anyway. Per the rationale in 67a3e8fe9015 ("nd_blk: change aperture mapping from WC to WB"), the only existing use is actually to invalidate clean cache lines for ARCH_MEMREMAP_PMEM type mappings *without* writeback. Since the recent cleanup of the pmem API, that also now happens to be the exact purpose of arch_invalidate_pmem(), which would be a far more well-defined tool for the job. Rather than risk potentially inconsistent implementations of mmio_flush_range() for the sake of one callsite, streamline things by removing it entirely and instead move the ARCH_MEMREMAP_PMEM related definitions up to the libnvdimm level, so they can be shared by NFIT as well. This allows NFIT to be enabled for arm64. Signed-off-by: Robin Murphy Signed-off-by: Dan Williams --- arch/x86/Kconfig | 1 - arch/x86/include/asm/cacheflush.h | 2 -- drivers/acpi/nfit/Kconfig | 2 +- drivers/acpi/nfit/core.c | 2 +- drivers/nvdimm/pmem.h | 14 -------------- include/linux/libnvdimm.h | 15 +++++++++++++++ lib/Kconfig | 3 --- tools/testing/nvdimm/test/nfit.c | 4 ++-- 8 files changed, 19 insertions(+), 24 deletions(-) (limited to 'lib') diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 781521b7cf9e..5f3b756ec0d3 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -53,7 +53,6 @@ config X86 select ARCH_HAS_FORTIFY_SOURCE select ARCH_HAS_GCOV_PROFILE_ALL select ARCH_HAS_KCOV if X86_64 - select ARCH_HAS_MMIO_FLUSH select ARCH_HAS_PMEM_API if X86_64 select ARCH_HAS_UACCESS_FLUSHCACHE if X86_64 select ARCH_HAS_SET_MEMORY diff --git a/arch/x86/include/asm/cacheflush.h b/arch/x86/include/asm/cacheflush.h index 8b4140f6724f..cb9a1af109b4 100644 --- a/arch/x86/include/asm/cacheflush.h +++ b/arch/x86/include/asm/cacheflush.h @@ -7,6 +7,4 @@ void clflush_cache_range(void *addr, unsigned int size); -#define mmio_flush_range(addr, size) clflush_cache_range(addr, size) - #endif /* _ASM_X86_CACHEFLUSH_H */ diff --git a/drivers/acpi/nfit/Kconfig b/drivers/acpi/nfit/Kconfig index 6d3351452ea2..929ba4da0b30 100644 --- a/drivers/acpi/nfit/Kconfig +++ b/drivers/acpi/nfit/Kconfig @@ -2,7 +2,7 @@ config ACPI_NFIT tristate "ACPI NVDIMM Firmware Interface Table (NFIT)" depends on PHYS_ADDR_T_64BIT depends on BLK_DEV - depends on ARCH_HAS_MMIO_FLUSH + depends on ARCH_HAS_PMEM_API select LIBNVDIMM help Infrastructure to probe ACPI 6 compliant platforms for diff --git a/drivers/acpi/nfit/core.c b/drivers/acpi/nfit/core.c index 03105648f9b1..c20124a6eb49 100644 --- a/drivers/acpi/nfit/core.c +++ b/drivers/acpi/nfit/core.c @@ -1964,7 +1964,7 @@ static int acpi_nfit_blk_single_io(struct nfit_blk *nfit_blk, memcpy_flushcache(mmio->addr.aperture + offset, iobuf + copied, c); else { if (nfit_blk->dimm_flags & NFIT_BLK_READ_FLUSH) - mmio_flush_range((void __force *) + arch_invalidate_pmem((void __force *) mmio->addr.aperture + offset, c); memcpy(iobuf + copied, mmio->addr.aperture + offset, c); diff --git a/drivers/nvdimm/pmem.h b/drivers/nvdimm/pmem.h index 5434321cad67..c5917f040fa7 100644 --- a/drivers/nvdimm/pmem.h +++ b/drivers/nvdimm/pmem.h @@ -5,20 +5,6 @@ #include #include -#ifdef CONFIG_ARCH_HAS_PMEM_API -#define ARCH_MEMREMAP_PMEM MEMREMAP_WB -void arch_wb_cache_pmem(void *addr, size_t size); -void arch_invalidate_pmem(void *addr, size_t size); -#else -#define ARCH_MEMREMAP_PMEM MEMREMAP_WT -static inline void arch_wb_cache_pmem(void *addr, size_t size) -{ -} -static inline void arch_invalidate_pmem(void *addr, size_t size) -{ -} -#endif - /* this definition is in it's own header for tools/testing/nvdimm to consume */ struct pmem_device { /* One contiguous memory region per device */ diff --git a/include/linux/libnvdimm.h b/include/linux/libnvdimm.h index 9b8d81a7b80e..3eaad2fbf284 100644 --- a/include/linux/libnvdimm.h +++ b/include/linux/libnvdimm.h @@ -174,4 +174,19 @@ u64 nd_fletcher64(void *addr, size_t len, bool le); void nvdimm_flush(struct nd_region *nd_region); int nvdimm_has_flush(struct nd_region *nd_region); int nvdimm_has_cache(struct nd_region *nd_region); + +#ifdef CONFIG_ARCH_HAS_PMEM_API +#define ARCH_MEMREMAP_PMEM MEMREMAP_WB +void arch_wb_cache_pmem(void *addr, size_t size); +void arch_invalidate_pmem(void *addr, size_t size); +#else +#define ARCH_MEMREMAP_PMEM MEMREMAP_WT +static inline void arch_wb_cache_pmem(void *addr, size_t size) +{ +} +static inline void arch_invalidate_pmem(void *addr, size_t size) +{ +} +#endif + #endif /* __LIBNVDIMM_H__ */ diff --git a/lib/Kconfig b/lib/Kconfig index 6762529ad9e4..527da69e3be1 100644 --- a/lib/Kconfig +++ b/lib/Kconfig @@ -559,9 +559,6 @@ config ARCH_HAS_PMEM_API config ARCH_HAS_UACCESS_FLUSHCACHE bool -config ARCH_HAS_MMIO_FLUSH - bool - config STACKDEPOT bool select STACKTRACE diff --git a/tools/testing/nvdimm/test/nfit.c b/tools/testing/nvdimm/test/nfit.c index 4c2fa98ef39d..d20791c3f499 100644 --- a/tools/testing/nvdimm/test/nfit.c +++ b/tools/testing/nvdimm/test/nfit.c @@ -1546,8 +1546,8 @@ static int nfit_test_blk_do_io(struct nd_blk_region *ndbr, resource_size_t dpa, else { memcpy(iobuf, mmio->addr.base + dpa, len); - /* give us some some coverage of the mmio_flush_range() API */ - mmio_flush_range(mmio->addr.base + dpa, len); + /* give us some some coverage of the arch_invalidate_pmem() API */ + arch_invalidate_pmem(mmio->addr.base + dpa, len); } nd_region_release_lane(nd_region, lane); -- cgit v1.2.3 From 3b3c4babd898715926d24ae10aa64778ace33aae Mon Sep 17 00:00:00 2001 From: Matthew Wilcox Date: Fri, 8 Sep 2017 16:13:48 -0700 Subject: lib/string.c: add multibyte memset functions Patch series "Multibyte memset variations", v4. A relatively common idiom we're missing is a function to fill an area of memory with a pattern which is larger than a single byte. I first noticed this with a zram patch which wanted to fill a page with an 'unsigned long' value. There turn out to be quite a few places in the kernel which can benefit from using an optimised function rather than a loop; sometimes text size, sometimes speed, and sometimes both. The optimised PowerPC version (not included here) improves performance by about 30% on POWER8 on just the raw memset_l(). Most of the extra lines of code come from the three testcases I added. This patch (of 8): memset16(), memset32() and memset64() are like memset(), but allow the caller to fill the destination with a value larger than a single byte. memset_l() and memset_p() allow the caller to use unsigned long and pointer values respectively. Link: http://lkml.kernel.org/r/20170720184539.31609-2-willy@infradead.org Signed-off-by: Matthew Wilcox Cc: "H. Peter Anvin" Cc: "James E.J. Bottomley" Cc: "Martin K. Petersen" Cc: David Miller Cc: Ingo Molnar Cc: Ivan Kokshaysky Cc: Matt Turner Cc: Michael Ellerman Cc: Minchan Kim Cc: Ralf Baechle Cc: Richard Henderson Cc: Russell King Cc: Sam Ravnborg Cc: Sergey Senozhatsky Cc: Thomas Gleixner Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- include/linux/string.h | 30 +++++++++++++++++++++++ lib/string.c | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) (limited to 'lib') diff --git a/include/linux/string.h b/include/linux/string.h index a467e617eeb0..c8bdafffd2f0 100644 --- a/include/linux/string.h +++ b/include/linux/string.h @@ -99,6 +99,36 @@ extern __kernel_size_t strcspn(const char *,const char *); #ifndef __HAVE_ARCH_MEMSET extern void * memset(void *,int,__kernel_size_t); #endif + +#ifndef __HAVE_ARCH_MEMSET16 +extern void *memset16(uint16_t *, uint16_t, __kernel_size_t); +#endif + +#ifndef __HAVE_ARCH_MEMSET32 +extern void *memset32(uint32_t *, uint32_t, __kernel_size_t); +#endif + +#ifndef __HAVE_ARCH_MEMSET64 +extern void *memset64(uint64_t *, uint64_t, __kernel_size_t); +#endif + +static inline void *memset_l(unsigned long *p, unsigned long v, + __kernel_size_t n) +{ + if (BITS_PER_LONG == 32) + return memset32((uint32_t *)p, v, n); + else + return memset64((uint64_t *)p, v, n); +} + +static inline void *memset_p(void **p, void *v, __kernel_size_t n) +{ + if (BITS_PER_LONG == 32) + return memset32((uint32_t *)p, (uintptr_t)v, n); + else + return memset64((uint64_t *)p, (uintptr_t)v, n); +} + #ifndef __HAVE_ARCH_MEMCPY extern void * memcpy(void *,const void *,__kernel_size_t); #endif diff --git a/lib/string.c b/lib/string.c index ebbb99c775bd..198148bb61fd 100644 --- a/lib/string.c +++ b/lib/string.c @@ -723,6 +723,72 @@ void memzero_explicit(void *s, size_t count) } EXPORT_SYMBOL(memzero_explicit); +#ifndef __HAVE_ARCH_MEMSET16 +/** + * memset16() - Fill a memory area with a uint16_t + * @s: Pointer to the start of the area. + * @v: The value to fill the area with + * @count: The number of values to store + * + * Differs from memset() in that it fills with a uint16_t instead + * of a byte. Remember that @count is the number of uint16_ts to + * store, not the number of bytes. + */ +void *memset16(uint16_t *s, uint16_t v, size_t count) +{ + uint16_t *xs = s; + + while (count--) + *xs++ = v; + return s; +} +EXPORT_SYMBOL(memset16); +#endif + +#ifndef __HAVE_ARCH_MEMSET32 +/** + * memset32() - Fill a memory area with a uint32_t + * @s: Pointer to the start of the area. + * @v: The value to fill the area with + * @count: The number of values to store + * + * Differs from memset() in that it fills with a uint32_t instead + * of a byte. Remember that @count is the number of uint32_ts to + * store, not the number of bytes. + */ +void *memset32(uint32_t *s, uint32_t v, size_t count) +{ + uint32_t *xs = s; + + while (count--) + *xs++ = v; + return s; +} +EXPORT_SYMBOL(memset32); +#endif + +#ifndef __HAVE_ARCH_MEMSET64 +/** + * memset64() - Fill a memory area with a uint64_t + * @s: Pointer to the start of the area. + * @v: The value to fill the area with + * @count: The number of values to store + * + * Differs from memset() in that it fills with a uint64_t instead + * of a byte. Remember that @count is the number of uint64_ts to + * store, not the number of bytes. + */ +void *memset64(uint64_t *s, uint64_t v, size_t count) +{ + uint64_t *xs = s; + + while (count--) + *xs++ = v; + return s; +} +EXPORT_SYMBOL(memset64); +#endif + #ifndef __HAVE_ARCH_MEMCPY /** * memcpy - Copy one area of memory to another -- cgit v1.2.3 From 03270c13c5ffaa6ac76fe70d0b6929313ca73d86 Mon Sep 17 00:00:00 2001 From: Matthew Wilcox Date: Fri, 8 Sep 2017 16:13:52 -0700 Subject: lib/string.c: add testcases for memset16/32/64 [akpm@linux-foundation.org: minor tweaks] Link: http://lkml.kernel.org/r/20170720184539.31609-3-willy@infradead.org Signed-off-by: Matthew Wilcox Cc: "H. Peter Anvin" Cc: "James E.J. Bottomley" Cc: "Martin K. Petersen" Cc: David Miller Cc: Ingo Molnar Cc: Ivan Kokshaysky Cc: Matt Turner Cc: Michael Ellerman Cc: Minchan Kim Cc: Ralf Baechle Cc: Richard Henderson Cc: Russell King Cc: Sam Ravnborg Cc: Sergey Senozhatsky Cc: Thomas Gleixner Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/Kconfig | 3 ++ lib/string.c | 129 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) (limited to 'lib') diff --git a/lib/Kconfig b/lib/Kconfig index 6762529ad9e4..40b114a11d7c 100644 --- a/lib/Kconfig +++ b/lib/Kconfig @@ -575,4 +575,7 @@ config PARMAN config PRIME_NUMBERS tristate +config STRING_SELFTEST + bool "Test string functions" + endmenu diff --git a/lib/string.c b/lib/string.c index 198148bb61fd..abf6499e3915 100644 --- a/lib/string.c +++ b/lib/string.c @@ -1051,3 +1051,132 @@ void fortify_panic(const char *name) BUG(); } EXPORT_SYMBOL(fortify_panic); + +#ifdef CONFIG_STRING_SELFTEST +#include +#include + +static __init int memset16_selftest(void) +{ + unsigned i, j, k; + u16 v, *p = kmalloc(256 * 2 * 2, GFP_KERNEL); + + for (i = 0; i < 256; i++) { + for (j = 0; j < 256; j++) { + memset(p, 0xa1, 256 * 2 * sizeof(v)); + memset16(p + i, 0xb1b2, j); + for (k = 0; k < 512; k++) { + v = p[k]; + if (k < i) { + if (v != 0xa1a1) + goto fail; + } else if (k < i + j) { + if (v != 0xb1b2) + goto fail; + } else { + if (v != 0xa1a1) + goto fail; + } + } + } + } + +fail: + kfree(p); + if (i < 256) + return (i << 24) | (j << 16) | k; + return 0; +} + +static __init int memset32_selftest(void) +{ + unsigned i, j, k; + u32 v, *p = kmalloc(256 * 2 * 4, GFP_KERNEL); + + for (i = 0; i < 256; i++) { + for (j = 0; j < 256; j++) { + memset(p, 0xa1, 256 * 2 * sizeof(v)); + memset32(p + i, 0xb1b2b3b4, j); + for (k = 0; k < 512; k++) { + v = p[k]; + if (k < i) { + if (v != 0xa1a1a1a1) + goto fail; + } else if (k < i + j) { + if (v != 0xb1b2b3b4) + goto fail; + } else { + if (v != 0xa1a1a1a1) + goto fail; + } + } + } + } + +fail: + kfree(p); + if (i < 256) + return (i << 24) | (j << 16) | k; + return 0; +} + +static __init int memset64_selftest(void) +{ + unsigned i, j, k; + u64 v, *p = kmalloc(256 * 2 * 8, GFP_KERNEL); + + for (i = 0; i < 256; i++) { + for (j = 0; j < 256; j++) { + memset(p, 0xa1, 256 * 2 * sizeof(v)); + memset64(p + i, 0xb1b2b3b4b5b6b7b8ULL, j); + for (k = 0; k < 512; k++) { + v = p[k]; + if (k < i) { + if (v != 0xa1a1a1a1a1a1a1a1ULL) + goto fail; + } else if (k < i + j) { + if (v != 0xb1b2b3b4b5b6b7b8ULL) + goto fail; + } else { + if (v != 0xa1a1a1a1a1a1a1a1ULL) + goto fail; + } + } + } + } + +fail: + kfree(p); + if (i < 256) + return (i << 24) | (j << 16) | k; + return 0; +} + +static __init int string_selftest_init(void) +{ + int test, subtest; + + test = 1; + subtest = memset16_selftest(); + if (subtest) + goto fail; + + test = 2; + subtest = memset32_selftest(); + if (subtest) + goto fail; + + test = 3; + subtest = memset64_selftest(); + if (subtest) + goto fail; + + pr_info("String selftests succeeded\n"); + return 0; +fail: + pr_crit("String selftest failure %d.%08x\n", test, subtest); + return 0; +} + +module_init(string_selftest_init); +#endif /* CONFIG_STRING_SELFTEST */ -- cgit v1.2.3 From cd9e61ed1eebbcd5dfad59475d41ec58d9b64b6a Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:36 -0700 Subject: rbtree: cache leftmost node internally Patch series "rbtree: Cache leftmost node internally", v4. A series to extending rbtrees to internally cache the leftmost node such that we can have fast overlap check optimization for all interval tree users[1]. The benefits of this series are that: (i) Unify users that do internal leftmost node caching. (ii) Optimize all interval tree users. (iii) Convert at least two new users (epoll and procfs) to the new interface. This patch (of 16): Red-black tree semantics imply that nodes with smaller or greater (or equal for duplicates) keys always be to the left and right, respectively. For the kernel this is extremely evident when considering our rb_first() semantics. Enabling lookups for the smallest node in the tree in O(1) can save a good chunk of cycles in not having to walk down the tree each time. To this end there are a few core users that explicitly do this, such as the scheduler and rtmutexes. There is also the desire for interval trees to have this optimization allowing faster overlap checking. This patch introduces a new 'struct rb_root_cached' which is just the root with a cached pointer to the leftmost node. The reason why the regular rb_root was not extended instead of adding a new structure was that this allows the user to have the choice between memory footprint and actual tree performance. The new wrappers on top of the regular rb_root calls are: - rb_first_cached(cached_root) -- which is a fast replacement for rb_first. - rb_insert_color_cached(node, cached_root, new) - rb_erase_cached(node, cached_root) In addition, augmented cached interfaces are also added for basic insertion and deletion operations; which becomes important for the interval tree changes. With the exception of the inserts, which adds a bool for updating the new leftmost, the interfaces are kept the same. To this end, porting rb users to the cached version becomes really trivial, and keeping current rbtree semantics for users that don't care about the optimization requires zero overhead. Link: http://lkml.kernel.org/r/20170719014603.19029-2-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Reviewed-by: Jan Kara Acked-by: Peter Zijlstra (Intel) Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- Documentation/rbtree.txt | 33 +++++++++++++++++++++++++++++++++ include/linux/rbtree.h | 21 +++++++++++++++++++++ include/linux/rbtree_augmented.h | 33 ++++++++++++++++++++++++++++++--- lib/rbtree.c | 34 +++++++++++++++++++++++++++++----- 4 files changed, 113 insertions(+), 8 deletions(-) (limited to 'lib') diff --git a/Documentation/rbtree.txt b/Documentation/rbtree.txt index b8a8c70b0188..c42a21b99046 100644 --- a/Documentation/rbtree.txt +++ b/Documentation/rbtree.txt @@ -193,6 +193,39 @@ Example:: for (node = rb_first(&mytree); node; node = rb_next(node)) printk("key=%s\n", rb_entry(node, struct mytype, node)->keystring); +Cached rbtrees +-------------- + +Computing the leftmost (smallest) node is quite a common task for binary +search trees, such as for traversals or users relying on a the particular +order for their own logic. To this end, users can use 'struct rb_root_cached' +to optimize O(logN) rb_first() calls to a simple pointer fetch avoiding +potentially expensive tree iterations. This is done at negligible runtime +overhead for maintanence; albeit larger memory footprint. + +Similar to the rb_root structure, cached rbtrees are initialized to be +empty via: + + struct rb_root_cached mytree = RB_ROOT_CACHED; + +Cached rbtree is simply a regular rb_root with an extra pointer to cache the +leftmost node. This allows rb_root_cached to exist wherever rb_root does, +which permits augmented trees to be supported as well as only a few extra +interfaces: + + struct rb_node *rb_first_cached(struct rb_root_cached *tree); + void rb_insert_color_cached(struct rb_node *, struct rb_root_cached *, bool); + void rb_erase_cached(struct rb_node *node, struct rb_root_cached *); + +Both insert and erase calls have their respective counterpart of augmented +trees: + + void rb_insert_augmented_cached(struct rb_node *node, struct rb_root_cached *, + bool, struct rb_augment_callbacks *); + void rb_erase_augmented_cached(struct rb_node *, struct rb_root_cached *, + struct rb_augment_callbacks *); + + Support for Augmented rbtrees ----------------------------- diff --git a/include/linux/rbtree.h b/include/linux/rbtree.h index e585018498d5..d574361943ea 100644 --- a/include/linux/rbtree.h +++ b/include/linux/rbtree.h @@ -44,10 +44,25 @@ struct rb_root { struct rb_node *rb_node; }; +/* + * Leftmost-cached rbtrees. + * + * We do not cache the rightmost node based on footprint + * size vs number of potential users that could benefit + * from O(1) rb_last(). Just not worth it, users that want + * this feature can always implement the logic explicitly. + * Furthermore, users that want to cache both pointers may + * find it a bit asymmetric, but that's ok. + */ +struct rb_root_cached { + struct rb_root rb_root; + struct rb_node *rb_leftmost; +}; #define rb_parent(r) ((struct rb_node *)((r)->__rb_parent_color & ~3)) #define RB_ROOT (struct rb_root) { NULL, } +#define RB_ROOT_CACHED (struct rb_root_cached) { {NULL, }, NULL } #define rb_entry(ptr, type, member) container_of(ptr, type, member) #define RB_EMPTY_ROOT(root) (READ_ONCE((root)->rb_node) == NULL) @@ -69,6 +84,12 @@ extern struct rb_node *rb_prev(const struct rb_node *); extern struct rb_node *rb_first(const struct rb_root *); extern struct rb_node *rb_last(const struct rb_root *); +extern void rb_insert_color_cached(struct rb_node *, + struct rb_root_cached *, bool); +extern void rb_erase_cached(struct rb_node *node, struct rb_root_cached *); +/* Same as rb_first(), but O(1) */ +#define rb_first_cached(root) (root)->rb_leftmost + /* Postorder iteration - always visit the parent after its children */ extern struct rb_node *rb_first_postorder(const struct rb_root *); extern struct rb_node *rb_next_postorder(const struct rb_node *); diff --git a/include/linux/rbtree_augmented.h b/include/linux/rbtree_augmented.h index 9702b6e183bc..6bfd2b581f75 100644 --- a/include/linux/rbtree_augmented.h +++ b/include/linux/rbtree_augmented.h @@ -41,7 +41,9 @@ struct rb_augment_callbacks { void (*rotate)(struct rb_node *old, struct rb_node *new); }; -extern void __rb_insert_augmented(struct rb_node *node, struct rb_root *root, +extern void __rb_insert_augmented(struct rb_node *node, + struct rb_root *root, + bool newleft, struct rb_node **leftmost, void (*augment_rotate)(struct rb_node *old, struct rb_node *new)); /* * Fixup the rbtree and update the augmented information when rebalancing. @@ -57,7 +59,16 @@ static inline void rb_insert_augmented(struct rb_node *node, struct rb_root *root, const struct rb_augment_callbacks *augment) { - __rb_insert_augmented(node, root, augment->rotate); + __rb_insert_augmented(node, root, false, NULL, augment->rotate); +} + +static inline void +rb_insert_augmented_cached(struct rb_node *node, + struct rb_root_cached *root, bool newleft, + const struct rb_augment_callbacks *augment) +{ + __rb_insert_augmented(node, &root->rb_root, + newleft, &root->rb_leftmost, augment->rotate); } #define RB_DECLARE_CALLBACKS(rbstatic, rbname, rbstruct, rbfield, \ @@ -150,6 +161,7 @@ extern void __rb_erase_color(struct rb_node *parent, struct rb_root *root, static __always_inline struct rb_node * __rb_erase_augmented(struct rb_node *node, struct rb_root *root, + struct rb_node **leftmost, const struct rb_augment_callbacks *augment) { struct rb_node *child = node->rb_right; @@ -157,6 +169,9 @@ __rb_erase_augmented(struct rb_node *node, struct rb_root *root, struct rb_node *parent, *rebalance; unsigned long pc; + if (leftmost && node == *leftmost) + *leftmost = rb_next(node); + if (!tmp) { /* * Case 1: node to erase has no more than 1 child (easy!) @@ -256,9 +271,21 @@ static __always_inline void rb_erase_augmented(struct rb_node *node, struct rb_root *root, const struct rb_augment_callbacks *augment) { - struct rb_node *rebalance = __rb_erase_augmented(node, root, augment); + struct rb_node *rebalance = __rb_erase_augmented(node, root, + NULL, augment); if (rebalance) __rb_erase_color(rebalance, root, augment->rotate); } +static __always_inline void +rb_erase_augmented_cached(struct rb_node *node, struct rb_root_cached *root, + const struct rb_augment_callbacks *augment) +{ + struct rb_node *rebalance = __rb_erase_augmented(node, &root->rb_root, + &root->rb_leftmost, + augment); + if (rebalance) + __rb_erase_color(rebalance, &root->rb_root, augment->rotate); +} + #endif /* _LINUX_RBTREE_AUGMENTED_H */ diff --git a/lib/rbtree.c b/lib/rbtree.c index 4ba2828a67c0..d102d9d2ffaa 100644 --- a/lib/rbtree.c +++ b/lib/rbtree.c @@ -95,10 +95,14 @@ __rb_rotate_set_parents(struct rb_node *old, struct rb_node *new, static __always_inline void __rb_insert(struct rb_node *node, struct rb_root *root, + bool newleft, struct rb_node **leftmost, void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) { struct rb_node *parent = rb_red_parent(node), *gparent, *tmp; + if (newleft) + *leftmost = node; + while (true) { /* * Loop invariant: node is red @@ -434,19 +438,38 @@ static const struct rb_augment_callbacks dummy_callbacks = { void rb_insert_color(struct rb_node *node, struct rb_root *root) { - __rb_insert(node, root, dummy_rotate); + __rb_insert(node, root, false, NULL, dummy_rotate); } EXPORT_SYMBOL(rb_insert_color); void rb_erase(struct rb_node *node, struct rb_root *root) { struct rb_node *rebalance; - rebalance = __rb_erase_augmented(node, root, &dummy_callbacks); + rebalance = __rb_erase_augmented(node, root, + NULL, &dummy_callbacks); if (rebalance) ____rb_erase_color(rebalance, root, dummy_rotate); } EXPORT_SYMBOL(rb_erase); +void rb_insert_color_cached(struct rb_node *node, + struct rb_root_cached *root, bool leftmost) +{ + __rb_insert(node, &root->rb_root, leftmost, + &root->rb_leftmost, dummy_rotate); +} +EXPORT_SYMBOL(rb_insert_color_cached); + +void rb_erase_cached(struct rb_node *node, struct rb_root_cached *root) +{ + struct rb_node *rebalance; + rebalance = __rb_erase_augmented(node, &root->rb_root, + &root->rb_leftmost, &dummy_callbacks); + if (rebalance) + ____rb_erase_color(rebalance, &root->rb_root, dummy_rotate); +} +EXPORT_SYMBOL(rb_erase_cached); + /* * Augmented rbtree manipulation functions. * @@ -455,9 +478,10 @@ EXPORT_SYMBOL(rb_erase); */ void __rb_insert_augmented(struct rb_node *node, struct rb_root *root, + bool newleft, struct rb_node **leftmost, void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) { - __rb_insert(node, root, augment_rotate); + __rb_insert(node, root, newleft, leftmost, augment_rotate); } EXPORT_SYMBOL(__rb_insert_augmented); @@ -502,7 +526,7 @@ struct rb_node *rb_next(const struct rb_node *node) * as we can. */ if (node->rb_right) { - node = node->rb_right; + node = node->rb_right; while (node->rb_left) node=node->rb_left; return (struct rb_node *)node; @@ -534,7 +558,7 @@ struct rb_node *rb_prev(const struct rb_node *node) * as we can. */ if (node->rb_left) { - node = node->rb_left; + node = node->rb_left; while (node->rb_right) node=node->rb_right; return (struct rb_node *)node; -- cgit v1.2.3 From 2aadf7fc7df9e70c99786ffb8452ccdd83d49e59 Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:39 -0700 Subject: rbtree: optimize root-check during rebalancing loop The only times the nil-parent (root node) condition is true is when the node is the first in the tree, or after fixing rbtree rule #4 and the case 1 rebalancing made the node the root. Such conditions do not apply most of the time: (i) The common case in an rbtree is to have more than a single node, so this is only true for the first rb_insert(). (ii) While there is a chance only one first rotation is needed, cases where the node's uncle is black (cases 2,3) are more common as we can have the following scenarios during the rotation looping: case1 only, case1+1, case2+3, case1+2+3, case3 only, etc. This patch, therefore, adds an unlikely() optimization to this conditional. When profiling with CONFIG_PROFILE_ANNOTATED_BRANCHES, a kernel build shows that the incorrect rate is less than 15%, and for workloads that involve insert mostly trees overtime tend to have less than 2% incorrect rate. Link: http://lkml.kernel.org/r/20170719014603.19029-3-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/rbtree.c | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'lib') diff --git a/lib/rbtree.c b/lib/rbtree.c index d102d9d2ffaa..e7cce12f404f 100644 --- a/lib/rbtree.c +++ b/lib/rbtree.c @@ -105,16 +105,25 @@ __rb_insert(struct rb_node *node, struct rb_root *root, while (true) { /* - * Loop invariant: node is red - * - * If there is a black parent, we are done. - * Otherwise, take some corrective action as we don't - * want a red root or two consecutive red nodes. + * Loop invariant: node is red. */ - if (!parent) { + if (unlikely(!parent)) { + /* + * The inserted node is root. Either this is the + * first node, or we recursed at Case 1 below and + * are no longer violating 4). + */ rb_set_parent_color(node, NULL, RB_BLACK); break; - } else if (rb_is_black(parent)) + } + + /* + * If there is a black parent, we are done. + * Otherwise, take some corrective action as, + * per 4), we don't want a red root or two + * consecutive red nodes. + */ + if(rb_is_black(parent)) break; gparent = rb_red_parent(parent); -- cgit v1.2.3 From 35dc67d7d922b2c9a1adb006c7a0f370eeb5c114 Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:42 -0700 Subject: rbtree: add some additional comments for rebalancing cases While overall the code is very nicely commented, it might not be immediately obvious from the diagrams what is going on. Add a very brief summary of each case. Opposite cases where the node is the left child are left untouched. Link: http://lkml.kernel.org/r/20170719014603.19029-4-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/rbtree.c | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/rbtree.c b/lib/rbtree.c index e7cce12f404f..ba4a9d165f1b 100644 --- a/lib/rbtree.c +++ b/lib/rbtree.c @@ -132,7 +132,7 @@ __rb_insert(struct rb_node *node, struct rb_root *root, if (parent != tmp) { /* parent == gparent->rb_left */ if (tmp && rb_is_red(tmp)) { /* - * Case 1 - color flips + * Case 1 - node's uncle is red (color flips). * * G g * / \ / \ @@ -155,7 +155,8 @@ __rb_insert(struct rb_node *node, struct rb_root *root, tmp = parent->rb_right; if (node == tmp) { /* - * Case 2 - left rotate at parent + * Case 2 - node's uncle is black and node is + * the parent's right child (left rotate at parent). * * G G * / \ / \ @@ -179,7 +180,8 @@ __rb_insert(struct rb_node *node, struct rb_root *root, } /* - * Case 3 - right rotate at gparent + * Case 3 - node's uncle is black and node is + * the parent's left child (right rotate at gparent). * * G P * / \ / \ -- cgit v1.2.3 From 223f8911eace60c787f8767c25148b80ece9732a Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:46 -0700 Subject: lib/rbtree_test.c: make input module parameters Allows for more flexible debugging. Link: http://lkml.kernel.org/r/20170719014603.19029-5-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/rbtree_test.c | 55 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 21 deletions(-) (limited to 'lib') diff --git a/lib/rbtree_test.c b/lib/rbtree_test.c index 8b3c9dc88262..e83331aa1b7f 100644 --- a/lib/rbtree_test.c +++ b/lib/rbtree_test.c @@ -1,11 +1,18 @@ #include +#include #include #include +#include #include -#define NODES 100 -#define PERF_LOOPS 100000 -#define CHECK_LOOPS 100 +#define __param(type, name, init, msg) \ + static type name = init; \ + module_param(name, type, 0444); \ + MODULE_PARM_DESC(name, msg); + +__param(int, nnodes, 100, "Number of nodes in the rb-tree"); +__param(int, perf_loops, 100000, "Number of iterations modifying the rb-tree"); +__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree"); struct test_node { u32 key; @@ -17,7 +24,7 @@ struct test_node { }; static struct rb_root root = RB_ROOT; -static struct test_node nodes[NODES]; +static struct test_node *nodes = NULL; static struct rnd_state rnd; @@ -95,7 +102,7 @@ static void erase_augmented(struct test_node *node, struct rb_root *root) static void init(void) { int i; - for (i = 0; i < NODES; i++) { + for (i = 0; i < nnodes; i++) { nodes[i].key = prandom_u32_state(&rnd); nodes[i].val = prandom_u32_state(&rnd); } @@ -177,6 +184,10 @@ static int __init rbtree_test_init(void) int i, j; cycles_t time1, time2, time; + nodes = kmalloc(nnodes * sizeof(*nodes), GFP_KERNEL); + if (!nodes) + return -ENOMEM; + printk(KERN_ALERT "rbtree testing"); prandom_seed_state(&rnd, 3141592653589793238ULL); @@ -184,27 +195,27 @@ static int __init rbtree_test_init(void) time1 = get_cycles(); - for (i = 0; i < PERF_LOOPS; i++) { - for (j = 0; j < NODES; j++) + for (i = 0; i < perf_loops; i++) { + for (j = 0; j < nnodes; j++) insert(nodes + j, &root); - for (j = 0; j < NODES; j++) + for (j = 0; j < nnodes; j++) erase(nodes + j, &root); } time2 = get_cycles(); time = time2 - time1; - time = div_u64(time, PERF_LOOPS); + time = div_u64(time, perf_loops); printk(" -> %llu cycles\n", (unsigned long long)time); - for (i = 0; i < CHECK_LOOPS; i++) { + for (i = 0; i < check_loops; i++) { init(); - for (j = 0; j < NODES; j++) { + for (j = 0; j < nnodes; j++) { check(j); insert(nodes + j, &root); } - for (j = 0; j < NODES; j++) { - check(NODES - j); + for (j = 0; j < nnodes; j++) { + check(nnodes - j); erase(nodes + j, &root); } check(0); @@ -216,32 +227,34 @@ static int __init rbtree_test_init(void) time1 = get_cycles(); - for (i = 0; i < PERF_LOOPS; i++) { - for (j = 0; j < NODES; j++) + for (i = 0; i < perf_loops; i++) { + for (j = 0; j < nnodes; j++) insert_augmented(nodes + j, &root); - for (j = 0; j < NODES; j++) + for (j = 0; j < nnodes; j++) erase_augmented(nodes + j, &root); } time2 = get_cycles(); time = time2 - time1; - time = div_u64(time, PERF_LOOPS); + time = div_u64(time, perf_loops); printk(" -> %llu cycles\n", (unsigned long long)time); - for (i = 0; i < CHECK_LOOPS; i++) { + for (i = 0; i < check_loops; i++) { init(); - for (j = 0; j < NODES; j++) { + for (j = 0; j < nnodes; j++) { check_augmented(j); insert_augmented(nodes + j, &root); } - for (j = 0; j < NODES; j++) { - check_augmented(NODES - j); + for (j = 0; j < nnodes; j++) { + check_augmented(nnodes - j); erase_augmented(nodes + j, &root); } check_augmented(0); } + kfree(nodes); + return -EAGAIN; /* Fail will directly unload the module */ } -- cgit v1.2.3 From 977bd8d5e1e61dc877c468e8937a4ab3094e53eb Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:49 -0700 Subject: lib/rbtree_test.c: add (inorder) traversal test This adds a second test for regular rb-tree testing in that there is no need to repeat it for the augmented flavor. Link: http://lkml.kernel.org/r/20170719014603.19029-6-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/rbtree_test.c | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/rbtree_test.c b/lib/rbtree_test.c index e83331aa1b7f..967999ff46db 100644 --- a/lib/rbtree_test.c +++ b/lib/rbtree_test.c @@ -183,6 +183,7 @@ static int __init rbtree_test_init(void) { int i, j; cycles_t time1, time2, time; + struct rb_node *node; nodes = kmalloc(nnodes * sizeof(*nodes), GFP_KERNEL); if (!nodes) @@ -206,8 +207,28 @@ static int __init rbtree_test_init(void) time = time2 - time1; time = div_u64(time, perf_loops); - printk(" -> %llu cycles\n", (unsigned long long)time); + printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); + for (i = 0; i < nnodes; i++) + insert(nodes + i, &root); + + time1 = get_cycles(); + + for (i = 0; i < perf_loops; i++) { + for (node = rb_first(&root); node; node = rb_next(node)) + ; + } + + time2 = get_cycles(); + time = time2 - time1; + + time = div_u64(time, perf_loops); + printk(" -> test 2 (latency of inorder traversal): %llu cycles\n", (unsigned long long)time); + + for (i = 0; i < nnodes; i++) + erase(nodes + i, &root); + + /* run checks */ for (i = 0; i < check_loops; i++) { init(); for (j = 0; j < nnodes; j++) { @@ -238,7 +259,7 @@ static int __init rbtree_test_init(void) time = time2 - time1; time = div_u64(time, perf_loops); - printk(" -> %llu cycles\n", (unsigned long long)time); + printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); for (i = 0; i < check_loops; i++) { init(); -- cgit v1.2.3 From b10d43f9898e0b56f6cf2375093dbd2c5db54486 Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:14:52 -0700 Subject: lib/rbtree_test.c: support rb_root_cached We can work with a single rb_root_cached root to test both cached and non-cached rbtrees. In addition, also add a test to measure latencies between rb_first and its fast counterpart. Link: http://lkml.kernel.org/r/20170719014603.19029-7-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/rbtree_test.c | 156 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 137 insertions(+), 19 deletions(-) (limited to 'lib') diff --git a/lib/rbtree_test.c b/lib/rbtree_test.c index 967999ff46db..191a238e5a9d 100644 --- a/lib/rbtree_test.c +++ b/lib/rbtree_test.c @@ -23,14 +23,14 @@ struct test_node { u32 augmented; }; -static struct rb_root root = RB_ROOT; +static struct rb_root_cached root = RB_ROOT_CACHED; static struct test_node *nodes = NULL; static struct rnd_state rnd; -static void insert(struct test_node *node, struct rb_root *root) +static void insert(struct test_node *node, struct rb_root_cached *root) { - struct rb_node **new = &root->rb_node, *parent = NULL; + struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; u32 key = node->key; while (*new) { @@ -42,14 +42,40 @@ static void insert(struct test_node *node, struct rb_root *root) } rb_link_node(&node->rb, parent, new); - rb_insert_color(&node->rb, root); + rb_insert_color(&node->rb, &root->rb_root); } -static inline void erase(struct test_node *node, struct rb_root *root) +static void insert_cached(struct test_node *node, struct rb_root_cached *root) { - rb_erase(&node->rb, root); + struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; + u32 key = node->key; + bool leftmost = true; + + while (*new) { + parent = *new; + if (key < rb_entry(parent, struct test_node, rb)->key) + new = &parent->rb_left; + else { + new = &parent->rb_right; + leftmost = false; + } + } + + rb_link_node(&node->rb, parent, new); + rb_insert_color_cached(&node->rb, root, leftmost); } +static inline void erase(struct test_node *node, struct rb_root_cached *root) +{ + rb_erase(&node->rb, &root->rb_root); +} + +static inline void erase_cached(struct test_node *node, struct rb_root_cached *root) +{ + rb_erase_cached(&node->rb, root); +} + + static inline u32 augment_recompute(struct test_node *node) { u32 max = node->val, child_augmented; @@ -71,9 +97,10 @@ static inline u32 augment_recompute(struct test_node *node) RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb, u32, augmented, augment_recompute) -static void insert_augmented(struct test_node *node, struct rb_root *root) +static void insert_augmented(struct test_node *node, + struct rb_root_cached *root) { - struct rb_node **new = &root->rb_node, *rb_parent = NULL; + struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; u32 key = node->key; u32 val = node->val; struct test_node *parent; @@ -91,12 +118,47 @@ static void insert_augmented(struct test_node *node, struct rb_root *root) node->augmented = val; rb_link_node(&node->rb, rb_parent, new); - rb_insert_augmented(&node->rb, root, &augment_callbacks); + rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks); +} + +static void insert_augmented_cached(struct test_node *node, + struct rb_root_cached *root) +{ + struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; + u32 key = node->key; + u32 val = node->val; + struct test_node *parent; + bool leftmost = true; + + while (*new) { + rb_parent = *new; + parent = rb_entry(rb_parent, struct test_node, rb); + if (parent->augmented < val) + parent->augmented = val; + if (key < parent->key) + new = &parent->rb.rb_left; + else { + new = &parent->rb.rb_right; + leftmost = false; + } + } + + node->augmented = val; + rb_link_node(&node->rb, rb_parent, new); + rb_insert_augmented_cached(&node->rb, root, + leftmost, &augment_callbacks); +} + + +static void erase_augmented(struct test_node *node, struct rb_root_cached *root) +{ + rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks); } -static void erase_augmented(struct test_node *node, struct rb_root *root) +static void erase_augmented_cached(struct test_node *node, + struct rb_root_cached *root) { - rb_erase_augmented(&node->rb, root, &augment_callbacks); + rb_erase_augmented_cached(&node->rb, root, &augment_callbacks); } static void init(void) @@ -125,7 +187,7 @@ static void check_postorder_foreach(int nr_nodes) { struct test_node *cur, *n; int count = 0; - rbtree_postorder_for_each_entry_safe(cur, n, &root, rb) + rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb) count++; WARN_ON_ONCE(count != nr_nodes); @@ -135,7 +197,7 @@ static void check_postorder(int nr_nodes) { struct rb_node *rb; int count = 0; - for (rb = rb_first_postorder(&root); rb; rb = rb_next_postorder(rb)) + for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb)) count++; WARN_ON_ONCE(count != nr_nodes); @@ -147,7 +209,7 @@ static void check(int nr_nodes) int count = 0, blacks = 0; u32 prev_key = 0; - for (rb = rb_first(&root); rb; rb = rb_next(rb)) { + for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { struct test_node *node = rb_entry(rb, struct test_node, rb); WARN_ON_ONCE(node->key < prev_key); WARN_ON_ONCE(is_red(rb) && @@ -162,7 +224,7 @@ static void check(int nr_nodes) } WARN_ON_ONCE(count != nr_nodes); - WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root))) - 1); + WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1); check_postorder(nr_nodes); check_postorder_foreach(nr_nodes); @@ -173,7 +235,7 @@ static void check_augmented(int nr_nodes) struct rb_node *rb; check(nr_nodes); - for (rb = rb_first(&root); rb; rb = rb_next(rb)) { + for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { struct test_node *node = rb_entry(rb, struct test_node, rb); WARN_ON_ONCE(node->augmented != augment_recompute(node)); } @@ -207,7 +269,24 @@ static int __init rbtree_test_init(void) time = time2 - time1; time = div_u64(time, perf_loops); - printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); + printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", + (unsigned long long)time); + + time1 = get_cycles(); + + for (i = 0; i < perf_loops; i++) { + for (j = 0; j < nnodes; j++) + insert_cached(nodes + j, &root); + for (j = 0; j < nnodes; j++) + erase_cached(nodes + j, &root); + } + + time2 = get_cycles(); + time = time2 - time1; + + time = div_u64(time, perf_loops); + printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", + (unsigned long long)time); for (i = 0; i < nnodes; i++) insert(nodes + i, &root); @@ -215,7 +294,7 @@ static int __init rbtree_test_init(void) time1 = get_cycles(); for (i = 0; i < perf_loops; i++) { - for (node = rb_first(&root); node; node = rb_next(node)) + for (node = rb_first(&root.rb_root); node; node = rb_next(node)) ; } @@ -223,7 +302,31 @@ static int __init rbtree_test_init(void) time = time2 - time1; time = div_u64(time, perf_loops); - printk(" -> test 2 (latency of inorder traversal): %llu cycles\n", (unsigned long long)time); + printk(" -> test 3 (latency of inorder traversal): %llu cycles\n", + (unsigned long long)time); + + time1 = get_cycles(); + + for (i = 0; i < perf_loops; i++) + node = rb_first(&root.rb_root); + + time2 = get_cycles(); + time = time2 - time1; + + time = div_u64(time, perf_loops); + printk(" -> test 4 (latency to fetch first node)\n"); + printk(" non-cached: %llu cycles\n", (unsigned long long)time); + + time1 = get_cycles(); + + for (i = 0; i < perf_loops; i++) + node = rb_first_cached(&root); + + time2 = get_cycles(); + time = time2 - time1; + + time = div_u64(time, perf_loops); + printk(" cached: %llu cycles\n", (unsigned long long)time); for (i = 0; i < nnodes; i++) erase(nodes + i, &root); @@ -261,6 +364,21 @@ static int __init rbtree_test_init(void) time = div_u64(time, perf_loops); printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); + time1 = get_cycles(); + + for (i = 0; i < perf_loops; i++) { + for (j = 0; j < nnodes; j++) + insert_augmented_cached(nodes + j, &root); + for (j = 0; j < nnodes; j++) + erase_augmented_cached(nodes + j, &root); + } + + time2 = get_cycles(); + time = time2 - time1; + + time = div_u64(time, perf_loops); + printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time); + for (i = 0; i < check_loops; i++) { init(); for (j = 0; j < nnodes; j++) { -- cgit v1.2.3 From f808c13fd3738948e10196496959871130612b61 Mon Sep 17 00:00:00 2001 From: Davidlohr Bueso Date: Fri, 8 Sep 2017 16:15:08 -0700 Subject: lib/interval_tree: fast overlap detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow interval trees to quickly check for overlaps to avoid unnecesary tree lookups in interval_tree_iter_first(). As of this patch, all interval tree flavors will require using a 'rb_root_cached' such that we can have the leftmost node easily available. While most users will make use of this feature, those with special functions (in addition to the generic insert, delete, search calls) will avoid using the cached option as they can do funky things with insertions -- for example, vma_interval_tree_insert_after(). [jglisse@redhat.com: fix deadlock from typo vm_lock_anon_vma()] Link: http://lkml.kernel.org/r/20170808225719.20723-1-jglisse@redhat.com Link: http://lkml.kernel.org/r/20170719014603.19029-12-dave@stgolabs.net Signed-off-by: Davidlohr Bueso Signed-off-by: Jérôme Glisse Acked-by: Christian König Acked-by: Peter Zijlstra (Intel) Acked-by: Doug Ledford Acked-by: Michael S. Tsirkin Cc: David Airlie Cc: Jason Wang Cc: Christian Benvenuti Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c | 8 ++-- drivers/gpu/drm/amd/amdgpu/amdgpu_vm.c | 7 ++-- drivers/gpu/drm/amd/amdgpu/amdgpu_vm.h | 2 +- drivers/gpu/drm/drm_mm.c | 19 +++++---- drivers/gpu/drm/drm_vma_manager.c | 2 +- drivers/gpu/drm/i915/i915_gem_userptr.c | 6 +-- drivers/gpu/drm/radeon/radeon.h | 2 +- drivers/gpu/drm/radeon/radeon_mn.c | 8 ++-- drivers/gpu/drm/radeon/radeon_vm.c | 7 ++-- drivers/infiniband/core/umem_rbtree.c | 4 +- drivers/infiniband/core/uverbs_cmd.c | 2 +- drivers/infiniband/hw/hfi1/mmu_rb.c | 10 ++--- drivers/infiniband/hw/usnic/usnic_uiom.c | 6 +-- drivers/infiniband/hw/usnic/usnic_uiom.h | 2 +- .../infiniband/hw/usnic/usnic_uiom_interval_tree.c | 15 +++---- .../infiniband/hw/usnic/usnic_uiom_interval_tree.h | 12 +++--- drivers/vhost/vhost.c | 2 +- drivers/vhost/vhost.h | 2 +- fs/hugetlbfs/inode.c | 6 +-- fs/inode.c | 2 +- include/drm/drm_mm.h | 2 +- include/linux/fs.h | 4 +- include/linux/interval_tree.h | 8 ++-- include/linux/interval_tree_generic.h | 46 +++++++++++++++++----- include/linux/mm.h | 17 ++++---- include/linux/rmap.h | 4 +- include/rdma/ib_umem_odp.h | 11 ++++-- include/rdma/ib_verbs.h | 2 +- lib/interval_tree_test.c | 4 +- mm/interval_tree.c | 10 ++--- mm/memory.c | 4 +- mm/mmap.c | 10 ++--- mm/rmap.c | 4 +- 33 files changed, 145 insertions(+), 105 deletions(-) (limited to 'lib') diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c index e1cde6b80027..3b0f2ec6eec7 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c @@ -51,7 +51,7 @@ struct amdgpu_mn { /* objects protected by lock */ struct mutex lock; - struct rb_root objects; + struct rb_root_cached objects; }; struct amdgpu_mn_node { @@ -76,8 +76,8 @@ static void amdgpu_mn_destroy(struct work_struct *work) mutex_lock(&adev->mn_lock); mutex_lock(&rmn->lock); hash_del(&rmn->node); - rbtree_postorder_for_each_entry_safe(node, next_node, &rmn->objects, - it.rb) { + rbtree_postorder_for_each_entry_safe(node, next_node, + &rmn->objects.rb_root, it.rb) { list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) { bo->mn = NULL; list_del_init(&bo->mn_list); @@ -221,7 +221,7 @@ static struct amdgpu_mn *amdgpu_mn_get(struct amdgpu_device *adev) rmn->mm = mm; rmn->mn.ops = &amdgpu_mn_ops; mutex_init(&rmn->lock); - rmn->objects = RB_ROOT; + rmn->objects = RB_ROOT_CACHED; r = __mmu_notifier_register(&rmn->mn, mm); if (r) diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.c index 6b1343e5541d..b9a5a77eedaf 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.c @@ -2475,7 +2475,7 @@ int amdgpu_vm_init(struct amdgpu_device *adev, struct amdgpu_vm *vm, u64 flags; uint64_t init_pde_value = 0; - vm->va = RB_ROOT; + vm->va = RB_ROOT_CACHED; vm->client_id = atomic64_inc_return(&adev->vm_manager.client_counter); for (i = 0; i < AMDGPU_MAX_VMHUBS; i++) vm->reserved_vmid[i] = NULL; @@ -2596,10 +2596,11 @@ void amdgpu_vm_fini(struct amdgpu_device *adev, struct amdgpu_vm *vm) amd_sched_entity_fini(vm->entity.sched, &vm->entity); - if (!RB_EMPTY_ROOT(&vm->va)) { + if (!RB_EMPTY_ROOT(&vm->va.rb_root)) { dev_err(adev->dev, "still active bo inside vm\n"); } - rbtree_postorder_for_each_entry_safe(mapping, tmp, &vm->va, rb) { + rbtree_postorder_for_each_entry_safe(mapping, tmp, + &vm->va.rb_root, rb) { list_del(&mapping->list); amdgpu_vm_it_remove(mapping, &vm->va); kfree(mapping); diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.h b/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.h index ba6691b58ee7..6716355403ec 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_vm.h @@ -118,7 +118,7 @@ struct amdgpu_vm_pt { struct amdgpu_vm { /* tree of virtual addresses mapped */ - struct rb_root va; + struct rb_root_cached va; /* protecting invalidated */ spinlock_t status_lock; diff --git a/drivers/gpu/drm/drm_mm.c b/drivers/gpu/drm/drm_mm.c index f794089d30ac..61a1c8ea74bc 100644 --- a/drivers/gpu/drm/drm_mm.c +++ b/drivers/gpu/drm/drm_mm.c @@ -169,7 +169,7 @@ INTERVAL_TREE_DEFINE(struct drm_mm_node, rb, struct drm_mm_node * __drm_mm_interval_first(const struct drm_mm *mm, u64 start, u64 last) { - return drm_mm_interval_tree_iter_first((struct rb_root *)&mm->interval_tree, + return drm_mm_interval_tree_iter_first((struct rb_root_cached *)&mm->interval_tree, start, last) ?: (struct drm_mm_node *)&mm->head_node; } EXPORT_SYMBOL(__drm_mm_interval_first); @@ -180,6 +180,7 @@ static void drm_mm_interval_tree_add_node(struct drm_mm_node *hole_node, struct drm_mm *mm = hole_node->mm; struct rb_node **link, *rb; struct drm_mm_node *parent; + bool leftmost = true; node->__subtree_last = LAST(node); @@ -196,9 +197,10 @@ static void drm_mm_interval_tree_add_node(struct drm_mm_node *hole_node, rb = &hole_node->rb; link = &hole_node->rb.rb_right; + leftmost = false; } else { rb = NULL; - link = &mm->interval_tree.rb_node; + link = &mm->interval_tree.rb_root.rb_node; } while (*link) { @@ -208,14 +210,15 @@ static void drm_mm_interval_tree_add_node(struct drm_mm_node *hole_node, parent->__subtree_last = node->__subtree_last; if (node->start < parent->start) link = &parent->rb.rb_left; - else + else { link = &parent->rb.rb_right; + leftmost = true; + } } rb_link_node(&node->rb, rb, link); - rb_insert_augmented(&node->rb, - &mm->interval_tree, - &drm_mm_interval_tree_augment); + rb_insert_augmented_cached(&node->rb, &mm->interval_tree, leftmost, + &drm_mm_interval_tree_augment); } #define RB_INSERT(root, member, expr) do { \ @@ -577,7 +580,7 @@ void drm_mm_replace_node(struct drm_mm_node *old, struct drm_mm_node *new) *new = *old; list_replace(&old->node_list, &new->node_list); - rb_replace_node(&old->rb, &new->rb, &old->mm->interval_tree); + rb_replace_node(&old->rb, &new->rb, &old->mm->interval_tree.rb_root); if (drm_mm_hole_follows(old)) { list_replace(&old->hole_stack, &new->hole_stack); @@ -863,7 +866,7 @@ void drm_mm_init(struct drm_mm *mm, u64 start, u64 size) mm->color_adjust = NULL; INIT_LIST_HEAD(&mm->hole_stack); - mm->interval_tree = RB_ROOT; + mm->interval_tree = RB_ROOT_CACHED; mm->holes_size = RB_ROOT; mm->holes_addr = RB_ROOT; diff --git a/drivers/gpu/drm/drm_vma_manager.c b/drivers/gpu/drm/drm_vma_manager.c index d9100b565198..28f1226576f8 100644 --- a/drivers/gpu/drm/drm_vma_manager.c +++ b/drivers/gpu/drm/drm_vma_manager.c @@ -147,7 +147,7 @@ struct drm_vma_offset_node *drm_vma_offset_lookup_locked(struct drm_vma_offset_m struct rb_node *iter; unsigned long offset; - iter = mgr->vm_addr_space_mm.interval_tree.rb_node; + iter = mgr->vm_addr_space_mm.interval_tree.rb_root.rb_node; best = NULL; while (likely(iter)) { diff --git a/drivers/gpu/drm/i915/i915_gem_userptr.c b/drivers/gpu/drm/i915/i915_gem_userptr.c index f152a38d7079..23fd18bd1b56 100644 --- a/drivers/gpu/drm/i915/i915_gem_userptr.c +++ b/drivers/gpu/drm/i915/i915_gem_userptr.c @@ -49,7 +49,7 @@ struct i915_mmu_notifier { spinlock_t lock; struct hlist_node node; struct mmu_notifier mn; - struct rb_root objects; + struct rb_root_cached objects; struct workqueue_struct *wq; }; @@ -123,7 +123,7 @@ static void i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn, struct interval_tree_node *it; LIST_HEAD(cancelled); - if (RB_EMPTY_ROOT(&mn->objects)) + if (RB_EMPTY_ROOT(&mn->objects.rb_root)) return; /* interval ranges are inclusive, but invalidate range is exclusive */ @@ -172,7 +172,7 @@ i915_mmu_notifier_create(struct mm_struct *mm) spin_lock_init(&mn->lock); mn->mn.ops = &i915_gem_userptr_notifier; - mn->objects = RB_ROOT; + mn->objects = RB_ROOT_CACHED; mn->wq = alloc_workqueue("i915-userptr-release", WQ_UNBOUND, 0); if (mn->wq == NULL) { kfree(mn); diff --git a/drivers/gpu/drm/radeon/radeon.h b/drivers/gpu/drm/radeon/radeon.h index ec63bc5e9de7..8cbaeec090c9 100644 --- a/drivers/gpu/drm/radeon/radeon.h +++ b/drivers/gpu/drm/radeon/radeon.h @@ -924,7 +924,7 @@ struct radeon_vm_id { struct radeon_vm { struct mutex mutex; - struct rb_root va; + struct rb_root_cached va; /* protecting invalidated and freed */ spinlock_t status_lock; diff --git a/drivers/gpu/drm/radeon/radeon_mn.c b/drivers/gpu/drm/radeon/radeon_mn.c index 896f2cf51e4e..1d62288b7ee3 100644 --- a/drivers/gpu/drm/radeon/radeon_mn.c +++ b/drivers/gpu/drm/radeon/radeon_mn.c @@ -50,7 +50,7 @@ struct radeon_mn { /* objects protected by lock */ struct mutex lock; - struct rb_root objects; + struct rb_root_cached objects; }; struct radeon_mn_node { @@ -75,8 +75,8 @@ static void radeon_mn_destroy(struct work_struct *work) mutex_lock(&rdev->mn_lock); mutex_lock(&rmn->lock); hash_del(&rmn->node); - rbtree_postorder_for_each_entry_safe(node, next_node, &rmn->objects, - it.rb) { + rbtree_postorder_for_each_entry_safe(node, next_node, + &rmn->objects.rb_root, it.rb) { interval_tree_remove(&node->it, &rmn->objects); list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) { @@ -205,7 +205,7 @@ static struct radeon_mn *radeon_mn_get(struct radeon_device *rdev) rmn->mm = mm; rmn->mn.ops = &radeon_mn_ops; mutex_init(&rmn->lock); - rmn->objects = RB_ROOT; + rmn->objects = RB_ROOT_CACHED; r = __mmu_notifier_register(&rmn->mn, mm); if (r) diff --git a/drivers/gpu/drm/radeon/radeon_vm.c b/drivers/gpu/drm/radeon/radeon_vm.c index 5e82b408d522..e5c0e635e371 100644 --- a/drivers/gpu/drm/radeon/radeon_vm.c +++ b/drivers/gpu/drm/radeon/radeon_vm.c @@ -1185,7 +1185,7 @@ int radeon_vm_init(struct radeon_device *rdev, struct radeon_vm *vm) vm->ids[i].last_id_use = NULL; } mutex_init(&vm->mutex); - vm->va = RB_ROOT; + vm->va = RB_ROOT_CACHED; spin_lock_init(&vm->status_lock); INIT_LIST_HEAD(&vm->invalidated); INIT_LIST_HEAD(&vm->freed); @@ -1232,10 +1232,11 @@ void radeon_vm_fini(struct radeon_device *rdev, struct radeon_vm *vm) struct radeon_bo_va *bo_va, *tmp; int i, r; - if (!RB_EMPTY_ROOT(&vm->va)) { + if (!RB_EMPTY_ROOT(&vm->va.rb_root)) { dev_err(rdev->dev, "still active bo inside vm\n"); } - rbtree_postorder_for_each_entry_safe(bo_va, tmp, &vm->va, it.rb) { + rbtree_postorder_for_each_entry_safe(bo_va, tmp, + &vm->va.rb_root, it.rb) { interval_tree_remove(&bo_va->it, &vm->va); r = radeon_bo_reserve(bo_va->bo, false); if (!r) { diff --git a/drivers/infiniband/core/umem_rbtree.c b/drivers/infiniband/core/umem_rbtree.c index d176597b4d78..fc801920e341 100644 --- a/drivers/infiniband/core/umem_rbtree.c +++ b/drivers/infiniband/core/umem_rbtree.c @@ -72,7 +72,7 @@ INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last, /* @last is not a part of the interval. See comment for function * node_last. */ -int rbt_ib_umem_for_each_in_range(struct rb_root *root, +int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root, u64 start, u64 last, umem_call_back cb, void *cookie) @@ -95,7 +95,7 @@ int rbt_ib_umem_for_each_in_range(struct rb_root *root, } EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range); -struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root *root, +struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root, u64 addr, u64 length) { struct umem_odp_node *node; diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c index e0cb99860934..4ab30d832ac5 100644 --- a/drivers/infiniband/core/uverbs_cmd.c +++ b/drivers/infiniband/core/uverbs_cmd.c @@ -118,7 +118,7 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file, ucontext->closing = 0; #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING - ucontext->umem_tree = RB_ROOT; + ucontext->umem_tree = RB_ROOT_CACHED; init_rwsem(&ucontext->umem_rwsem); ucontext->odp_mrs_count = 0; INIT_LIST_HEAD(&ucontext->no_private_counters); diff --git a/drivers/infiniband/hw/hfi1/mmu_rb.c b/drivers/infiniband/hw/hfi1/mmu_rb.c index 2f0d285dc278..175002c046ed 100644 --- a/drivers/infiniband/hw/hfi1/mmu_rb.c +++ b/drivers/infiniband/hw/hfi1/mmu_rb.c @@ -54,7 +54,7 @@ struct mmu_rb_handler { struct mmu_notifier mn; - struct rb_root root; + struct rb_root_cached root; void *ops_arg; spinlock_t lock; /* protect the RB tree */ struct mmu_rb_ops *ops; @@ -108,7 +108,7 @@ int hfi1_mmu_rb_register(void *ops_arg, struct mm_struct *mm, if (!handlr) return -ENOMEM; - handlr->root = RB_ROOT; + handlr->root = RB_ROOT_CACHED; handlr->ops = ops; handlr->ops_arg = ops_arg; INIT_HLIST_NODE(&handlr->mn.hlist); @@ -149,9 +149,9 @@ void hfi1_mmu_rb_unregister(struct mmu_rb_handler *handler) INIT_LIST_HEAD(&del_list); spin_lock_irqsave(&handler->lock, flags); - while ((node = rb_first(&handler->root))) { + while ((node = rb_first_cached(&handler->root))) { rbnode = rb_entry(node, struct mmu_rb_node, node); - rb_erase(node, &handler->root); + rb_erase_cached(node, &handler->root); /* move from LRU list to delete list */ list_move(&rbnode->list, &del_list); } @@ -300,7 +300,7 @@ static void mmu_notifier_mem_invalidate(struct mmu_notifier *mn, { struct mmu_rb_handler *handler = container_of(mn, struct mmu_rb_handler, mn); - struct rb_root *root = &handler->root; + struct rb_root_cached *root = &handler->root; struct mmu_rb_node *node, *ptr = NULL; unsigned long flags; bool added = false; diff --git a/drivers/infiniband/hw/usnic/usnic_uiom.c b/drivers/infiniband/hw/usnic/usnic_uiom.c index c49db7c33979..4381c0a9a873 100644 --- a/drivers/infiniband/hw/usnic/usnic_uiom.c +++ b/drivers/infiniband/hw/usnic/usnic_uiom.c @@ -227,7 +227,7 @@ static void __usnic_uiom_reg_release(struct usnic_uiom_pd *pd, vpn_last = vpn_start + npages - 1; spin_lock(&pd->lock); - usnic_uiom_remove_interval(&pd->rb_root, vpn_start, + usnic_uiom_remove_interval(&pd->root, vpn_start, vpn_last, &rm_intervals); usnic_uiom_unmap_sorted_intervals(&rm_intervals, pd); @@ -379,7 +379,7 @@ struct usnic_uiom_reg *usnic_uiom_reg_get(struct usnic_uiom_pd *pd, err = usnic_uiom_get_intervals_diff(vpn_start, vpn_last, (writable) ? IOMMU_WRITE : 0, IOMMU_WRITE, - &pd->rb_root, + &pd->root, &sorted_diff_intervals); if (err) { usnic_err("Failed disjoint interval vpn [0x%lx,0x%lx] err %d\n", @@ -395,7 +395,7 @@ struct usnic_uiom_reg *usnic_uiom_reg_get(struct usnic_uiom_pd *pd, } - err = usnic_uiom_insert_interval(&pd->rb_root, vpn_start, vpn_last, + err = usnic_uiom_insert_interval(&pd->root, vpn_start, vpn_last, (writable) ? IOMMU_WRITE : 0); if (err) { usnic_err("Failed insert interval vpn [0x%lx,0x%lx] err %d\n", diff --git a/drivers/infiniband/hw/usnic/usnic_uiom.h b/drivers/infiniband/hw/usnic/usnic_uiom.h index 45ca7c1613a7..431efe4143f4 100644 --- a/drivers/infiniband/hw/usnic/usnic_uiom.h +++ b/drivers/infiniband/hw/usnic/usnic_uiom.h @@ -55,7 +55,7 @@ struct usnic_uiom_dev { struct usnic_uiom_pd { struct iommu_domain *domain; spinlock_t lock; - struct rb_root rb_root; + struct rb_root_cached root; struct list_head devs; int dev_cnt; }; diff --git a/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.c b/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.c index 42b4b4c4e452..d399523206c7 100644 --- a/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.c +++ b/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.c @@ -100,9 +100,9 @@ static int interval_cmp(void *priv, struct list_head *a, struct list_head *b) } static void -find_intervals_intersection_sorted(struct rb_root *root, unsigned long start, - unsigned long last, - struct list_head *list) +find_intervals_intersection_sorted(struct rb_root_cached *root, + unsigned long start, unsigned long last, + struct list_head *list) { struct usnic_uiom_interval_node *node; @@ -118,7 +118,7 @@ find_intervals_intersection_sorted(struct rb_root *root, unsigned long start, int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last, int flags, int flag_mask, - struct rb_root *root, + struct rb_root_cached *root, struct list_head *diff_set) { struct usnic_uiom_interval_node *interval, *tmp; @@ -175,7 +175,7 @@ void usnic_uiom_put_interval_set(struct list_head *intervals) kfree(interval); } -int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start, +int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start, unsigned long last, int flags) { struct usnic_uiom_interval_node *interval, *tmp; @@ -246,8 +246,9 @@ err_out: return err; } -void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start, - unsigned long last, struct list_head *removed) +void usnic_uiom_remove_interval(struct rb_root_cached *root, + unsigned long start, unsigned long last, + struct list_head *removed) { struct usnic_uiom_interval_node *interval; diff --git a/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.h b/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.h index c0b0b876ab90..1d7fc3226bca 100644 --- a/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.h +++ b/drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.h @@ -48,12 +48,12 @@ struct usnic_uiom_interval_node { extern void usnic_uiom_interval_tree_insert(struct usnic_uiom_interval_node *node, - struct rb_root *root); + struct rb_root_cached *root); extern void usnic_uiom_interval_tree_remove(struct usnic_uiom_interval_node *node, - struct rb_root *root); + struct rb_root_cached *root); extern struct usnic_uiom_interval_node * -usnic_uiom_interval_tree_iter_first(struct rb_root *root, +usnic_uiom_interval_tree_iter_first(struct rb_root_cached *root, unsigned long start, unsigned long last); extern struct usnic_uiom_interval_node * @@ -63,7 +63,7 @@ usnic_uiom_interval_tree_iter_next(struct usnic_uiom_interval_node *node, * Inserts {start...last} into {root}. If there are overlaps, * nodes will be broken up and merged */ -int usnic_uiom_insert_interval(struct rb_root *root, +int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start, unsigned long last, int flags); /* @@ -71,7 +71,7 @@ int usnic_uiom_insert_interval(struct rb_root *root, * 'removed.' The caller is responsibile for freeing memory of nodes in * 'removed.' */ -void usnic_uiom_remove_interval(struct rb_root *root, +void usnic_uiom_remove_interval(struct rb_root_cached *root, unsigned long start, unsigned long last, struct list_head *removed); /* @@ -81,7 +81,7 @@ void usnic_uiom_remove_interval(struct rb_root *root, int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last, int flags, int flag_mask, - struct rb_root *root, + struct rb_root_cached *root, struct list_head *diff_set); /* Call this to free diff_set returned by usnic_uiom_get_intervals_diff */ void usnic_uiom_put_interval_set(struct list_head *intervals); diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 9cb3f722dce1..d6dbb28245e6 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -1271,7 +1271,7 @@ static struct vhost_umem *vhost_umem_alloc(void) if (!umem) return NULL; - umem->umem_tree = RB_ROOT; + umem->umem_tree = RB_ROOT_CACHED; umem->numem = 0; INIT_LIST_HEAD(&umem->umem_list); diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index bb7c29b8b9fc..d59a9cc65f9d 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -71,7 +71,7 @@ struct vhost_umem_node { }; struct vhost_umem { - struct rb_root umem_tree; + struct rb_root_cached umem_tree; struct list_head umem_list; int numem; }; diff --git a/fs/hugetlbfs/inode.c b/fs/hugetlbfs/inode.c index 8c6f4b8f910f..59073e9f01a4 100644 --- a/fs/hugetlbfs/inode.c +++ b/fs/hugetlbfs/inode.c @@ -334,7 +334,7 @@ static void remove_huge_page(struct page *page) } static void -hugetlb_vmdelete_list(struct rb_root *root, pgoff_t start, pgoff_t end) +hugetlb_vmdelete_list(struct rb_root_cached *root, pgoff_t start, pgoff_t end) { struct vm_area_struct *vma; @@ -498,7 +498,7 @@ static int hugetlb_vmtruncate(struct inode *inode, loff_t offset) i_size_write(inode, offset); i_mmap_lock_write(mapping); - if (!RB_EMPTY_ROOT(&mapping->i_mmap)) + if (!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root)) hugetlb_vmdelete_list(&mapping->i_mmap, pgoff, 0); i_mmap_unlock_write(mapping); remove_inode_hugepages(inode, offset, LLONG_MAX); @@ -523,7 +523,7 @@ static long hugetlbfs_punch_hole(struct inode *inode, loff_t offset, loff_t len) inode_lock(inode); i_mmap_lock_write(mapping); - if (!RB_EMPTY_ROOT(&mapping->i_mmap)) + if (!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root)) hugetlb_vmdelete_list(&mapping->i_mmap, hole_start >> PAGE_SHIFT, hole_end >> PAGE_SHIFT); diff --git a/fs/inode.c b/fs/inode.c index 6a1626e0edaf..210054157a49 100644 --- a/fs/inode.c +++ b/fs/inode.c @@ -353,7 +353,7 @@ void address_space_init_once(struct address_space *mapping) init_rwsem(&mapping->i_mmap_rwsem); INIT_LIST_HEAD(&mapping->private_list); spin_lock_init(&mapping->private_lock); - mapping->i_mmap = RB_ROOT; + mapping->i_mmap = RB_ROOT_CACHED; } EXPORT_SYMBOL(address_space_init_once); diff --git a/include/drm/drm_mm.h b/include/drm/drm_mm.h index 49b292e98fec..8d10fc97801c 100644 --- a/include/drm/drm_mm.h +++ b/include/drm/drm_mm.h @@ -172,7 +172,7 @@ struct drm_mm { * according to the (increasing) start address of the memory node. */ struct drm_mm_node head_node; /* Keep an interval_tree for fast lookup of drm_mm_nodes by address. */ - struct rb_root interval_tree; + struct rb_root_cached interval_tree; struct rb_root holes_size; struct rb_root holes_addr; diff --git a/include/linux/fs.h b/include/linux/fs.h index 509434aaf5a4..6111976848ff 100644 --- a/include/linux/fs.h +++ b/include/linux/fs.h @@ -392,7 +392,7 @@ struct address_space { struct radix_tree_root page_tree; /* radix tree of all pages */ spinlock_t tree_lock; /* and lock protecting it */ atomic_t i_mmap_writable;/* count VM_SHARED mappings */ - struct rb_root i_mmap; /* tree of private and shared mappings */ + struct rb_root_cached i_mmap; /* tree of private and shared mappings */ struct rw_semaphore i_mmap_rwsem; /* protect tree, count, list */ /* Protected by tree_lock together with the radix tree */ unsigned long nrpages; /* number of total pages */ @@ -487,7 +487,7 @@ static inline void i_mmap_unlock_read(struct address_space *mapping) */ static inline int mapping_mapped(struct address_space *mapping) { - return !RB_EMPTY_ROOT(&mapping->i_mmap); + return !RB_EMPTY_ROOT(&mapping->i_mmap.rb_root); } /* diff --git a/include/linux/interval_tree.h b/include/linux/interval_tree.h index 724556aa3c95..202ee1283f4b 100644 --- a/include/linux/interval_tree.h +++ b/include/linux/interval_tree.h @@ -11,13 +11,15 @@ struct interval_tree_node { }; extern void -interval_tree_insert(struct interval_tree_node *node, struct rb_root *root); +interval_tree_insert(struct interval_tree_node *node, + struct rb_root_cached *root); extern void -interval_tree_remove(struct interval_tree_node *node, struct rb_root *root); +interval_tree_remove(struct interval_tree_node *node, + struct rb_root_cached *root); extern struct interval_tree_node * -interval_tree_iter_first(struct rb_root *root, +interval_tree_iter_first(struct rb_root_cached *root, unsigned long start, unsigned long last); extern struct interval_tree_node * diff --git a/include/linux/interval_tree_generic.h b/include/linux/interval_tree_generic.h index 58370e1862ad..f096423c8cbd 100644 --- a/include/linux/interval_tree_generic.h +++ b/include/linux/interval_tree_generic.h @@ -65,11 +65,13 @@ RB_DECLARE_CALLBACKS(static, ITPREFIX ## _augment, ITSTRUCT, ITRB, \ \ /* Insert / remove interval nodes from the tree */ \ \ -ITSTATIC void ITPREFIX ## _insert(ITSTRUCT *node, struct rb_root *root) \ +ITSTATIC void ITPREFIX ## _insert(ITSTRUCT *node, \ + struct rb_root_cached *root) \ { \ - struct rb_node **link = &root->rb_node, *rb_parent = NULL; \ + struct rb_node **link = &root->rb_root.rb_node, *rb_parent = NULL; \ ITTYPE start = ITSTART(node), last = ITLAST(node); \ ITSTRUCT *parent; \ + bool leftmost = true; \ \ while (*link) { \ rb_parent = *link; \ @@ -78,18 +80,22 @@ ITSTATIC void ITPREFIX ## _insert(ITSTRUCT *node, struct rb_root *root) \ parent->ITSUBTREE = last; \ if (start < ITSTART(parent)) \ link = &parent->ITRB.rb_left; \ - else \ + else { \ link = &parent->ITRB.rb_right; \ + leftmost = false; \ + } \ } \ \ node->ITSUBTREE = last; \ rb_link_node(&node->ITRB, rb_parent, link); \ - rb_insert_augmented(&node->ITRB, root, &ITPREFIX ## _augment); \ + rb_insert_augmented_cached(&node->ITRB, root, \ + leftmost, &ITPREFIX ## _augment); \ } \ \ -ITSTATIC void ITPREFIX ## _remove(ITSTRUCT *node, struct rb_root *root) \ +ITSTATIC void ITPREFIX ## _remove(ITSTRUCT *node, \ + struct rb_root_cached *root) \ { \ - rb_erase_augmented(&node->ITRB, root, &ITPREFIX ## _augment); \ + rb_erase_augmented_cached(&node->ITRB, root, &ITPREFIX ## _augment); \ } \ \ /* \ @@ -140,15 +146,35 @@ ITPREFIX ## _subtree_search(ITSTRUCT *node, ITTYPE start, ITTYPE last) \ } \ \ ITSTATIC ITSTRUCT * \ -ITPREFIX ## _iter_first(struct rb_root *root, ITTYPE start, ITTYPE last) \ +ITPREFIX ## _iter_first(struct rb_root_cached *root, \ + ITTYPE start, ITTYPE last) \ { \ - ITSTRUCT *node; \ + ITSTRUCT *node, *leftmost; \ \ - if (!root->rb_node) \ + if (!root->rb_root.rb_node) \ return NULL; \ - node = rb_entry(root->rb_node, ITSTRUCT, ITRB); \ + \ + /* \ + * Fastpath range intersection/overlap between A: [a0, a1] and \ + * B: [b0, b1] is given by: \ + * \ + * a0 <= b1 && b0 <= a1 \ + * \ + * ... where A holds the lock range and B holds the smallest \ + * 'start' and largest 'last' in the tree. For the later, we \ + * rely on the root node, which by augmented interval tree \ + * property, holds the largest value in its last-in-subtree. \ + * This allows mitigating some of the tree walk overhead for \ + * for non-intersecting ranges, maintained and consulted in O(1). \ + */ \ + node = rb_entry(root->rb_root.rb_node, ITSTRUCT, ITRB); \ if (node->ITSUBTREE < start) \ return NULL; \ + \ + leftmost = rb_entry(root->rb_leftmost, ITSTRUCT, ITRB); \ + if (ITSTART(leftmost) > last) \ + return NULL; \ + \ return ITPREFIX ## _subtree_search(node, start, last); \ } \ \ diff --git a/include/linux/mm.h b/include/linux/mm.h index 5195e272fc4a..f8c10d336e42 100644 --- a/include/linux/mm.h +++ b/include/linux/mm.h @@ -2034,13 +2034,13 @@ extern int nommu_shrink_inode_mappings(struct inode *, size_t, size_t); /* interval_tree.c */ void vma_interval_tree_insert(struct vm_area_struct *node, - struct rb_root *root); + struct rb_root_cached *root); void vma_interval_tree_insert_after(struct vm_area_struct *node, struct vm_area_struct *prev, - struct rb_root *root); + struct rb_root_cached *root); void vma_interval_tree_remove(struct vm_area_struct *node, - struct rb_root *root); -struct vm_area_struct *vma_interval_tree_iter_first(struct rb_root *root, + struct rb_root_cached *root); +struct vm_area_struct *vma_interval_tree_iter_first(struct rb_root_cached *root, unsigned long start, unsigned long last); struct vm_area_struct *vma_interval_tree_iter_next(struct vm_area_struct *node, unsigned long start, unsigned long last); @@ -2050,11 +2050,12 @@ struct vm_area_struct *vma_interval_tree_iter_next(struct vm_area_struct *node, vma; vma = vma_interval_tree_iter_next(vma, start, last)) void anon_vma_interval_tree_insert(struct anon_vma_chain *node, - struct rb_root *root); + struct rb_root_cached *root); void anon_vma_interval_tree_remove(struct anon_vma_chain *node, - struct rb_root *root); -struct anon_vma_chain *anon_vma_interval_tree_iter_first( - struct rb_root *root, unsigned long start, unsigned long last); + struct rb_root_cached *root); +struct anon_vma_chain * +anon_vma_interval_tree_iter_first(struct rb_root_cached *root, + unsigned long start, unsigned long last); struct anon_vma_chain *anon_vma_interval_tree_iter_next( struct anon_vma_chain *node, unsigned long start, unsigned long last); #ifdef CONFIG_DEBUG_VM_RB diff --git a/include/linux/rmap.h b/include/linux/rmap.h index f8ca2e74b819..733d3d8181e2 100644 --- a/include/linux/rmap.h +++ b/include/linux/rmap.h @@ -55,7 +55,9 @@ struct anon_vma { * is serialized by a system wide lock only visible to * mm_take_all_locks() (mm_all_locks_mutex). */ - struct rb_root rb_root; /* Interval tree of private "related" vmas */ + + /* Interval tree of private "related" vmas */ + struct rb_root_cached rb_root; }; /* diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h index fb67554aabd6..5eb7f5bc8248 100644 --- a/include/rdma/ib_umem_odp.h +++ b/include/rdma/ib_umem_odp.h @@ -111,22 +111,25 @@ int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 start_offset, u64 bcnt, void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 start_offset, u64 bound); -void rbt_ib_umem_insert(struct umem_odp_node *node, struct rb_root *root); -void rbt_ib_umem_remove(struct umem_odp_node *node, struct rb_root *root); +void rbt_ib_umem_insert(struct umem_odp_node *node, + struct rb_root_cached *root); +void rbt_ib_umem_remove(struct umem_odp_node *node, + struct rb_root_cached *root); typedef int (*umem_call_back)(struct ib_umem *item, u64 start, u64 end, void *cookie); /* * Call the callback on each ib_umem in the range. Returns the logical or of * the return values of the functions called. */ -int rbt_ib_umem_for_each_in_range(struct rb_root *root, u64 start, u64 end, +int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root, + u64 start, u64 end, umem_call_back cb, void *cookie); /* * Find first region intersecting with address range. * Return NULL if not found */ -struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root *root, +struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root, u64 addr, u64 length); static inline int ib_umem_mmu_notifier_retry(struct ib_umem *item, diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h index e6df68048517..bdb1279a415b 100644 --- a/include/rdma/ib_verbs.h +++ b/include/rdma/ib_verbs.h @@ -1457,7 +1457,7 @@ struct ib_ucontext { struct pid *tgid; #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING - struct rb_root umem_tree; + struct rb_root_cached umem_tree; /* * Protects .umem_rbroot and tree, as well as odp_mrs_count and * mmu notifiers registration. diff --git a/lib/interval_tree_test.c b/lib/interval_tree_test.c index df495fe81421..0e343fd29570 100644 --- a/lib/interval_tree_test.c +++ b/lib/interval_tree_test.c @@ -19,14 +19,14 @@ __param(bool, search_all, false, "Searches will iterate all nodes in the tree"); __param(uint, max_endpoint, ~0, "Largest value for the interval's endpoint"); -static struct rb_root root = RB_ROOT; +static struct rb_root_cached root = RB_ROOT_CACHED; static struct interval_tree_node *nodes = NULL; static u32 *queries = NULL; static struct rnd_state rnd; static inline unsigned long -search(struct rb_root *root, unsigned long start, unsigned long last) +search(struct rb_root_cached *root, unsigned long start, unsigned long last) { struct interval_tree_node *node; unsigned long results = 0; diff --git a/mm/interval_tree.c b/mm/interval_tree.c index f2c2492681bf..b47664358796 100644 --- a/mm/interval_tree.c +++ b/mm/interval_tree.c @@ -28,7 +28,7 @@ INTERVAL_TREE_DEFINE(struct vm_area_struct, shared.rb, /* Insert node immediately after prev in the interval tree */ void vma_interval_tree_insert_after(struct vm_area_struct *node, struct vm_area_struct *prev, - struct rb_root *root) + struct rb_root_cached *root) { struct rb_node **link; struct vm_area_struct *parent; @@ -55,7 +55,7 @@ void vma_interval_tree_insert_after(struct vm_area_struct *node, node->shared.rb_subtree_last = last; rb_link_node(&node->shared.rb, &parent->shared.rb, link); - rb_insert_augmented(&node->shared.rb, root, + rb_insert_augmented(&node->shared.rb, &root->rb_root, &vma_interval_tree_augment); } @@ -74,7 +74,7 @@ INTERVAL_TREE_DEFINE(struct anon_vma_chain, rb, unsigned long, rb_subtree_last, static inline, __anon_vma_interval_tree) void anon_vma_interval_tree_insert(struct anon_vma_chain *node, - struct rb_root *root) + struct rb_root_cached *root) { #ifdef CONFIG_DEBUG_VM_RB node->cached_vma_start = avc_start_pgoff(node); @@ -84,13 +84,13 @@ void anon_vma_interval_tree_insert(struct anon_vma_chain *node, } void anon_vma_interval_tree_remove(struct anon_vma_chain *node, - struct rb_root *root) + struct rb_root_cached *root) { __anon_vma_interval_tree_remove(node, root); } struct anon_vma_chain * -anon_vma_interval_tree_iter_first(struct rb_root *root, +anon_vma_interval_tree_iter_first(struct rb_root_cached *root, unsigned long first, unsigned long last) { return __anon_vma_interval_tree_iter_first(root, first, last); diff --git a/mm/memory.c b/mm/memory.c index 0bbc1d612a63..ec4e15494901 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -2761,7 +2761,7 @@ static void unmap_mapping_range_vma(struct vm_area_struct *vma, zap_page_range_single(vma, start_addr, end_addr - start_addr, details); } -static inline void unmap_mapping_range_tree(struct rb_root *root, +static inline void unmap_mapping_range_tree(struct rb_root_cached *root, struct zap_details *details) { struct vm_area_struct *vma; @@ -2825,7 +2825,7 @@ void unmap_mapping_range(struct address_space *mapping, details.last_index = ULONG_MAX; i_mmap_lock_write(mapping); - if (unlikely(!RB_EMPTY_ROOT(&mapping->i_mmap))) + if (unlikely(!RB_EMPTY_ROOT(&mapping->i_mmap.rb_root))) unmap_mapping_range_tree(&mapping->i_mmap, &details); i_mmap_unlock_write(mapping); } diff --git a/mm/mmap.c b/mm/mmap.c index 4c5981651407..680506faceae 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -685,7 +685,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, struct mm_struct *mm = vma->vm_mm; struct vm_area_struct *next = vma->vm_next, *orig_vma = vma; struct address_space *mapping = NULL; - struct rb_root *root = NULL; + struct rb_root_cached *root = NULL; struct anon_vma *anon_vma = NULL; struct file *file = vma->vm_file; bool start_changed = false, end_changed = false; @@ -3340,7 +3340,7 @@ static DEFINE_MUTEX(mm_all_locks_mutex); static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma) { - if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_node)) { + if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) { /* * The LSB of head.next can't change from under us * because we hold the mm_all_locks_mutex. @@ -3356,7 +3356,7 @@ static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma) * anon_vma->root->rwsem. */ if (__test_and_set_bit(0, (unsigned long *) - &anon_vma->root->rb_root.rb_node)) + &anon_vma->root->rb_root.rb_root.rb_node)) BUG(); } } @@ -3458,7 +3458,7 @@ out_unlock: static void vm_unlock_anon_vma(struct anon_vma *anon_vma) { - if (test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_node)) { + if (test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) { /* * The LSB of head.next can't change to 0 from under * us because we hold the mm_all_locks_mutex. @@ -3472,7 +3472,7 @@ static void vm_unlock_anon_vma(struct anon_vma *anon_vma) * anon_vma->root->rwsem. */ if (!__test_and_clear_bit(0, (unsigned long *) - &anon_vma->root->rb_root.rb_node)) + &anon_vma->root->rb_root.rb_root.rb_node)) BUG(); anon_vma_unlock_write(anon_vma); } diff --git a/mm/rmap.c b/mm/rmap.c index 0618cd85b862..b874c4761e84 100644 --- a/mm/rmap.c +++ b/mm/rmap.c @@ -391,7 +391,7 @@ void unlink_anon_vmas(struct vm_area_struct *vma) * Leave empty anon_vmas on the list - we'll need * to free them outside the lock. */ - if (RB_EMPTY_ROOT(&anon_vma->rb_root)) { + if (RB_EMPTY_ROOT(&anon_vma->rb_root.rb_root)) { anon_vma->parent->degree--; continue; } @@ -425,7 +425,7 @@ static void anon_vma_ctor(void *data) init_rwsem(&anon_vma->rwsem); atomic_set(&anon_vma->refcount, 0); - anon_vma->rb_root = RB_ROOT; + anon_vma->rb_root = RB_ROOT_CACHED; } void __init anon_vma_init(void) -- cgit v1.2.3 From 9888a588ea96ba2804f955bbc2667346719da887 Mon Sep 17 00:00:00 2001 From: Andy Shevchenko Date: Fri, 8 Sep 2017 16:15:28 -0700 Subject: lib/hexdump.c: return -EINVAL in case of error in hex2bin() In some cases caller would like to use error code directly without shadowing. -EINVAL feels a rightful code to return in case of error in hex2bin(). Link: http://lkml.kernel.org/r/20170731135510.68023-1-andriy.shevchenko@linux.intel.com Signed-off-by: Andy Shevchenko Cc: Arnd Bergmann Cc: Rasmus Villemoes Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/hexdump.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/hexdump.c b/lib/hexdump.c index 992457b1284c..81b70ed37209 100644 --- a/lib/hexdump.c +++ b/lib/hexdump.c @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -42,7 +43,7 @@ EXPORT_SYMBOL(hex_to_bin); * @src: ascii hexadecimal string * @count: result length * - * Return 0 on success, -1 in case of bad input. + * Return 0 on success, -EINVAL in case of bad input. */ int hex2bin(u8 *dst, const char *src, size_t count) { @@ -51,7 +52,7 @@ int hex2bin(u8 *dst, const char *src, size_t count) int lo = hex_to_bin(*src++); if ((hi < 0) || (lo < 0)) - return -1; + return -EINVAL; *dst++ = (hi << 4) | lo; } -- cgit v1.2.3 From e4dace3615526fd66c86dd535ee4bc9e8c706e37 Mon Sep 17 00:00:00 2001 From: Florian Fainelli Date: Fri, 8 Sep 2017 16:15:31 -0700 Subject: lib: add test module for CONFIG_DEBUG_VIRTUAL Add a test module that allows testing that CONFIG_DEBUG_VIRTUAL works correctly, at least that it can catch invalid calls to virt_to_phys() against the non-linear kernel virtual address map. Link: http://lkml.kernel.org/r/20170808164035.26725-1-f.fainelli@gmail.com Signed-off-by: Florian Fainelli Cc: "Luis R. Rodriguez" Cc: Kees Cook Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/Kconfig.debug | 11 +++++++++++ lib/Makefile | 1 + lib/test_debug_virtual.c | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) create mode 100644 lib/test_debug_virtual.c (limited to 'lib') diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 7396f5044397..b19c491cbc4e 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -1930,6 +1930,17 @@ config TEST_KMOD If unsure, say N. +config TEST_DEBUG_VIRTUAL + tristate "Test CONFIG_DEBUG_VIRTUAL feature" + depends on DEBUG_VIRTUAL + help + Test the kernel's ability to detect incorrect calls to + virt_to_phys() done against the non-linear part of the + kernel's virtual address map. + + If unsure, say N. + + source "samples/Kconfig" source "lib/Kconfig.kgdb" diff --git a/lib/Makefile b/lib/Makefile index 40c18372b301..469ce5e24e4f 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -62,6 +62,7 @@ obj-$(CONFIG_TEST_BITMAP) += test_bitmap.o obj-$(CONFIG_TEST_UUID) += test_uuid.o obj-$(CONFIG_TEST_PARMAN) += test_parman.o obj-$(CONFIG_TEST_KMOD) += test_kmod.o +obj-$(CONFIG_TEST_DEBUG_VIRTUAL) += test_debug_virtual.o ifeq ($(CONFIG_DEBUG_KOBJECT),y) CFLAGS_kobject.o += -DDEBUG diff --git a/lib/test_debug_virtual.c b/lib/test_debug_virtual.c new file mode 100644 index 000000000000..b9cdeecc19dc --- /dev/null +++ b/lib/test_debug_virtual.c @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#ifdef CONFIG_MIPS +#include +#endif + +struct foo { + unsigned int bar; +}; + +struct foo *foo; + +static int __init test_debug_virtual_init(void) +{ + phys_addr_t pa; + void *va; + + va = (void *)VMALLOC_START; + pa = virt_to_phys(va); + + pr_info("PA: %pa for VA: 0x%lx\n", &pa, (unsigned long)va); + + foo = kzalloc(sizeof(*foo), GFP_KERNEL); + if (!foo) + return -ENOMEM; + + pa = virt_to_phys(foo); + va = foo; + pr_info("PA: %pa for VA: 0x%lx\n", &pa, (unsigned long)va); + + return 0; +} +module_init(test_debug_virtual_init); + +static void __exit test_debug_virtual_exit(void) +{ + kfree(foo); +} +module_exit(test_debug_virtual_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Test module for CONFIG_DEBUG_VIRTUAL"); -- cgit v1.2.3 From 0a5ce0831d04382aa9e2420e33dff958ddade542 Mon Sep 17 00:00:00 2001 From: Yury Norov Date: Fri, 8 Sep 2017 16:15:34 -0700 Subject: lib/bitmap.c: make bitmap_parselist() thread-safe and much faster Current implementation of bitmap_parselist() uses a static variable to save local state while setting bits in the bitmap. It is obviously wrong if we assume execution in multiprocessor environment. Fortunately, it's possible to rewrite this portion of code to avoid using the static variable. It is also possible to set bits in the mask per-range with bitmap_set(), not per-bit, as it is implemented now, with set_bit(); which is way faster. The important side effect of this change is that setting bits in this function from now is not per-bit atomic and less memory-ordered. This is because set_bit() guarantees the order of memory accesses, while bitmap_set() does not. I think that it is the advantage of the new approach, because the bitmap_parselist() is intended to initialise bit arrays, and user should protect the whole bitmap during initialisation if needed. So protecting individual bits looks expensive and useless. Also, other range-oriented functions in lib/bitmap.c don't worry much about atomicity. With all that, setting 2k bits in map with the pattern like 0-2047:128/256 becomes ~50 times faster after applying the patch in my testing environment (arm64 hosted on qemu). The second patch of the series adds the test for bitmap_parselist(). It's not intended to cover all tricky cases, just to make sure that I didn't screw up during rework. Link: http://lkml.kernel.org/r/20170807225438.16161-1-ynorov@caviumnetworks.com Signed-off-by: Yury Norov Cc: Noam Camus Cc: Rasmus Villemoes Cc: Matthew Wilcox Cc: Mauro Carvalho Chehab Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/bitmap.c | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'lib') diff --git a/lib/bitmap.c b/lib/bitmap.c index 9a532805364b..c82c61b66e16 100644 --- a/lib/bitmap.c +++ b/lib/bitmap.c @@ -513,7 +513,7 @@ static int __bitmap_parselist(const char *buf, unsigned int buflen, int nmaskbits) { unsigned int a, b, old_a, old_b; - unsigned int group_size, used_size; + unsigned int group_size, used_size, off; int c, old_c, totaldigits, ndigits; const char __user __force *ubuf = (const char __user __force *)buf; int at_start, in_range, in_partial_range; @@ -599,6 +599,8 @@ static int __bitmap_parselist(const char *buf, unsigned int buflen, a = old_a; b = old_b; old_a = old_b = 0; + } else { + used_size = group_size = b - a + 1; } /* if no digit is after '-', it's wrong*/ if (at_start && in_range) @@ -608,17 +610,9 @@ static int __bitmap_parselist(const char *buf, unsigned int buflen, if (b >= nmaskbits) return -ERANGE; while (a <= b) { - if (in_partial_range) { - static int pos_in_group = 1; - - if (pos_in_group <= used_size) - set_bit(a, maskp); - - if (a == b || ++pos_in_group > group_size) - pos_in_group = 1; - } else - set_bit(a, maskp); - a++; + off = min(b - a + 1, used_size); + bitmap_set(maskp, a, off); + a += group_size; } } while (buflen && c == ','); return 0; -- cgit v1.2.3 From 6df0d464dbcc1a55d10ded413beda167cb3c71b0 Mon Sep 17 00:00:00 2001 From: Yury Norov Date: Fri, 8 Sep 2017 16:15:38 -0700 Subject: lib/test_bitmap.c: add test for bitmap_parselist() Do some basic checks for bitmap_parselist(). [akpm@linux-foundation.org: fix printk warning] Link: http://lkml.kernel.org/r/20170807225438.16161-2-ynorov@caviumnetworks.com Signed-off-by: Yury Norov Cc: Noam Camus Cc: Rasmus Villemoes Cc: Matthew Wilcox Cc: Mauro Carvalho Chehab Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/test_bitmap.c | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) (limited to 'lib') diff --git a/lib/test_bitmap.c b/lib/test_bitmap.c index 2526a2975c51..431c97fb1aa0 100644 --- a/lib/test_bitmap.c +++ b/lib/test_bitmap.c @@ -165,6 +165,79 @@ static void __init test_zero_fill_copy(void) expect_eq_pbl("128-1023", bmap2, 1024); } +#define PARSE_TIME 0x1 + +struct test_bitmap_parselist{ + const int errno; + const char *in; + const unsigned long *expected; + const int nbits; + const int flags; +}; + +static const unsigned long exp[] = {1, 2, 0x0000ffff, 0xffff0000, 0x55555555, + 0xaaaaaaaa, 0x11111111, 0x22222222, 0xffffffff, + 0xfffffffe, 0x3333333311111111, 0xffffffff77777777}; +static const unsigned long exp2[] = {0x3333333311111111, 0xffffffff77777777}; + +static const struct test_bitmap_parselist parselist_tests[] __initconst = { + {0, "0", &exp[0], 8, 0}, + {0, "1", &exp[1], 8, 0}, + {0, "0-15", &exp[2], 32, 0}, + {0, "16-31", &exp[3], 32, 0}, + {0, "0-31:1/2", &exp[4], 32, 0}, + {0, "1-31:1/2", &exp[5], 32, 0}, + {0, "0-31:1/4", &exp[6], 32, 0}, + {0, "1-31:1/4", &exp[7], 32, 0}, + {0, "0-31:4/4", &exp[8], 32, 0}, + {0, "1-31:4/4", &exp[9], 32, 0}, + {0, "0-31:1/4,32-63:2/4", &exp[10], 64, 0}, + {0, "0-31:3/4,32-63:4/4", &exp[11], 64, 0}, + + {0, "0-31:1/4,32-63:2/4,64-95:3/4,96-127:4/4", exp2, 128, 0}, + + {0, "0-2047:128/256", NULL, 2048, PARSE_TIME}, + + {-EINVAL, "-1", NULL, 8, 0}, + {-EINVAL, "-0", NULL, 8, 0}, + {-EINVAL, "10-1", NULL, 8, 0}, + {-EINVAL, "0-31:10/1", NULL, 8, 0}, +}; + +static void __init test_bitmap_parselist(void) +{ + int i; + int err; + cycles_t cycles; + DECLARE_BITMAP(bmap, 2048); + + for (i = 0; i < ARRAY_SIZE(parselist_tests); i++) { +#define ptest parselist_tests[i] + + cycles = get_cycles(); + err = bitmap_parselist(ptest.in, bmap, ptest.nbits); + cycles = get_cycles() - cycles; + + if (err != ptest.errno) { + pr_err("test %d: input is %s, errno is %d, expected %d\n", + i, ptest.in, err, ptest.errno); + continue; + } + + if (!err && ptest.expected + && !__bitmap_equal(bmap, ptest.expected, ptest.nbits)) { + pr_err("test %d: input is %s, result is 0x%lx, expected 0x%lx\n", + i, ptest.in, bmap[0], *ptest.expected); + continue; + } + + if (ptest.flags & PARSE_TIME) + pr_err("test %d: input is '%s' OK, Time: %llu\n", + i, ptest.in, + (unsigned long long)cycles); + } +} + static void __init test_bitmap_u32_array_conversions(void) { DECLARE_BITMAP(bmap1, 1024); @@ -365,6 +438,7 @@ static int __init test_bitmap_init(void) { test_zero_fill_copy(); test_bitmap_u32_array_conversions(); + test_bitmap_parselist(); test_mem_optimisations(); if (failed_tests == 0) -- cgit v1.2.3 From 60ef690018b262ddcd0d51edf10e40710deb9c9f Mon Sep 17 00:00:00 2001 From: Yury Norov Date: Fri, 8 Sep 2017 16:15:41 -0700 Subject: bitmap: introduce BITMAP_FROM_U64() The macro is the compile-time analogue of bitmap_from_u64() with the same purpose: convert the 64-bit number to the properly ordered pair of 32-bit parts, suitable for filling the bitmap in 32-bit BE environment. Use it to make test_bitmap_parselist() correct for 32-bit BE ABIs. Tested on BE mips/qemu. [akpm@linux-foundation.org: tweak code comment] Link: http://lkml.kernel.org/r/20170810172916.24144-1-ynorov@caviumnetworks.com Signed-off-by: Yury Norov Cc: Noam Camus Cc: Rasmus Villemoes Cc: Matthew Wilcox Cc: Mauro Carvalho Chehab Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- include/linux/bitmap.h | 32 ++++++++++++++++++++++++++++++++ lib/test_bitmap.c | 47 ++++++++++++++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 15 deletions(-) (limited to 'lib') diff --git a/include/linux/bitmap.h b/include/linux/bitmap.h index 5797ca6fdfe2..700cf5f67118 100644 --- a/include/linux/bitmap.h +++ b/include/linux/bitmap.h @@ -360,6 +360,38 @@ static inline int bitmap_parse(const char *buf, unsigned int buflen, return __bitmap_parse(buf, buflen, 0, maskp, nmaskbits); } +/* + * BITMAP_FROM_U64() - Represent u64 value in the format suitable for bitmap. + * + * Linux bitmaps are internally arrays of unsigned longs, i.e. 32-bit + * integers in 32-bit environment, and 64-bit integers in 64-bit one. + * + * There are four combinations of endianness and length of the word in linux + * ABIs: LE64, BE64, LE32 and BE32. + * + * On 64-bit kernels 64-bit LE and BE numbers are naturally ordered in + * bitmaps and therefore don't require any special handling. + * + * On 32-bit kernels 32-bit LE ABI orders lo word of 64-bit number in memory + * prior to hi, and 32-bit BE orders hi word prior to lo. The bitmap on the + * other hand is represented as an array of 32-bit words and the position of + * bit N may therefore be calculated as: word #(N/32) and bit #(N%32) in that + * word. For example, bit #42 is located at 10th position of 2nd word. + * It matches 32-bit LE ABI, and we can simply let the compiler store 64-bit + * values in memory as it usually does. But for BE we need to swap hi and lo + * words manually. + * + * With all that, the macro BITMAP_FROM_U64() does explicit reordering of hi and + * lo parts of u64. For LE32 it does nothing, and for BE environment it swaps + * hi and lo words, as is expected by bitmap. + */ +#if __BITS_PER_LONG == 64 +#define BITMAP_FROM_U64(n) (n) +#else +#define BITMAP_FROM_U64(n) ((unsigned long) ((u64)(n) & ULONG_MAX)), \ + ((unsigned long) ((u64)(n) >> 32)) +#endif + /* * bitmap_from_u64 - Check and swap words within u64. * @mask: source bitmap diff --git a/lib/test_bitmap.c b/lib/test_bitmap.c index 431c97fb1aa0..599c6713f2a2 100644 --- a/lib/test_bitmap.c +++ b/lib/test_bitmap.c @@ -175,24 +175,41 @@ struct test_bitmap_parselist{ const int flags; }; -static const unsigned long exp[] = {1, 2, 0x0000ffff, 0xffff0000, 0x55555555, - 0xaaaaaaaa, 0x11111111, 0x22222222, 0xffffffff, - 0xfffffffe, 0x3333333311111111, 0xffffffff77777777}; -static const unsigned long exp2[] = {0x3333333311111111, 0xffffffff77777777}; +static const unsigned long exp[] __initconst = { + BITMAP_FROM_U64(1), + BITMAP_FROM_U64(2), + BITMAP_FROM_U64(0x0000ffff), + BITMAP_FROM_U64(0xffff0000), + BITMAP_FROM_U64(0x55555555), + BITMAP_FROM_U64(0xaaaaaaaa), + BITMAP_FROM_U64(0x11111111), + BITMAP_FROM_U64(0x22222222), + BITMAP_FROM_U64(0xffffffff), + BITMAP_FROM_U64(0xfffffffe), + BITMAP_FROM_U64(0x3333333311111111), + BITMAP_FROM_U64(0xffffffff77777777) +}; + +static const unsigned long exp2[] __initconst = { + BITMAP_FROM_U64(0x3333333311111111), + BITMAP_FROM_U64(0xffffffff77777777) +}; static const struct test_bitmap_parselist parselist_tests[] __initconst = { +#define step (sizeof(u64) / sizeof(unsigned long)) + {0, "0", &exp[0], 8, 0}, - {0, "1", &exp[1], 8, 0}, - {0, "0-15", &exp[2], 32, 0}, - {0, "16-31", &exp[3], 32, 0}, - {0, "0-31:1/2", &exp[4], 32, 0}, - {0, "1-31:1/2", &exp[5], 32, 0}, - {0, "0-31:1/4", &exp[6], 32, 0}, - {0, "1-31:1/4", &exp[7], 32, 0}, - {0, "0-31:4/4", &exp[8], 32, 0}, - {0, "1-31:4/4", &exp[9], 32, 0}, - {0, "0-31:1/4,32-63:2/4", &exp[10], 64, 0}, - {0, "0-31:3/4,32-63:4/4", &exp[11], 64, 0}, + {0, "1", &exp[1 * step], 8, 0}, + {0, "0-15", &exp[2 * step], 32, 0}, + {0, "16-31", &exp[3 * step], 32, 0}, + {0, "0-31:1/2", &exp[4 * step], 32, 0}, + {0, "1-31:1/2", &exp[5 * step], 32, 0}, + {0, "0-31:1/4", &exp[6 * step], 32, 0}, + {0, "1-31:1/4", &exp[7 * step], 32, 0}, + {0, "0-31:4/4", &exp[8 * step], 32, 0}, + {0, "1-31:4/4", &exp[9 * step], 32, 0}, + {0, "0-31:1/4,32-63:2/4", &exp[10 * step], 64, 0}, + {0, "0-31:3/4,32-63:4/4", &exp[11 * step], 64, 0}, {0, "0-31:1/4,32-63:2/4,64-95:3/4,96-127:4/4", exp2, 128, 0}, -- cgit v1.2.3 From da436528267ab45fb44ee52a28ebac85b8e24c3d Mon Sep 17 00:00:00 2001 From: Dan Carpenter Date: Fri, 8 Sep 2017 16:15:48 -0700 Subject: lib/string.c: check for kmalloc() failure This is mostly to keep the number of static checker warnings down so we can spot new bugs instead of them being drowned in noise. This function doesn't return normal kernel error codes but instead the return value is used to display exactly which memory failed. I chose -1 as hopefully that's a helpful thing to print. Link: http://lkml.kernel.org/r/20170817115420.uikisjvfmtrqkzjn@mwanda Signed-off-by: Dan Carpenter Cc: Matthew Wilcox Cc: Stephen Rothwell Cc: Kees Cook Cc: Bjorn Helgaas Cc: Mauro Carvalho Chehab Cc: Heikki Krogerus Cc: Daniel Micay Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/string.c | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/string.c b/lib/string.c index abf6499e3915..9921dc202db4 100644 --- a/lib/string.c +++ b/lib/string.c @@ -1059,7 +1059,11 @@ EXPORT_SYMBOL(fortify_panic); static __init int memset16_selftest(void) { unsigned i, j, k; - u16 v, *p = kmalloc(256 * 2 * 2, GFP_KERNEL); + u16 v, *p; + + p = kmalloc(256 * 2 * 2, GFP_KERNEL); + if (!p) + return -1; for (i = 0; i < 256; i++) { for (j = 0; j < 256; j++) { @@ -1091,7 +1095,11 @@ fail: static __init int memset32_selftest(void) { unsigned i, j, k; - u32 v, *p = kmalloc(256 * 2 * 4, GFP_KERNEL); + u32 v, *p; + + p = kmalloc(256 * 2 * 4, GFP_KERNEL); + if (!p) + return -1; for (i = 0; i < 256; i++) { for (j = 0; j < 256; j++) { @@ -1123,7 +1131,11 @@ fail: static __init int memset64_selftest(void) { unsigned i, j, k; - u64 v, *p = kmalloc(256 * 2 * 8, GFP_KERNEL); + u64 v, *p; + + p = kmalloc(256 * 2 * 8, GFP_KERNEL); + if (!p) + return -1; for (i = 0; i < 256; i++) { for (j = 0; j < 256; j++) { -- cgit v1.2.3 From 7c61bd6983b185272315722787cceacf5f5d2e7d Mon Sep 17 00:00:00 2001 From: Baoquan He Date: Fri, 8 Sep 2017 16:15:51 -0700 Subject: lib/cmdline.c: remove meaningless comment One line of code was commented out by c++ style comment for debugging, but forgot removing it. Clean it up. Link: http://lkml.kernel.org/r/1503312113-11843-1-git-send-email-bhe@redhat.com Signed-off-by: Baoquan He Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/cmdline.c | 1 - 1 file changed, 1 deletion(-) (limited to 'lib') diff --git a/lib/cmdline.c b/lib/cmdline.c index 4c0888c4a68d..171c19b6888e 100644 --- a/lib/cmdline.c +++ b/lib/cmdline.c @@ -244,5 +244,4 @@ char *next_arg(char *args, char **param, char **val) /* Chew up trailing spaces. */ return skip_spaces(next); - //return next; } -- cgit v1.2.3 From bc9ae2247ac92fd4d962507bafa3afffff6660ff Mon Sep 17 00:00:00 2001 From: Eric Dumazet Date: Fri, 8 Sep 2017 16:15:54 -0700 Subject: radix-tree: must check __radix_tree_preload() return value __radix_tree_preload() only disables preemption if no error is returned. So we really need to make sure callers always check the return value. idr_preload() contract is to always disable preemption, so we need to add a missing preempt_disable() if an error happened. Similarly, ida_pre_get() only needs to call preempt_enable() in the case no error happened. Link: http://lkml.kernel.org/r/1504637190.15310.62.camel@edumazet-glaptop3.roam.corp.google.com Fixes: 0a835c4f090a ("Reimplement IDR and IDA using the radix tree") Fixes: 7ad3d4d85c7a ("ida: Move ida_bitmap to a percpu variable") Signed-off-by: Eric Dumazet Cc: Matthew Wilcox Cc: "Kirill A. Shutemov" Cc: [4.11+] Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/radix-tree.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/lib/radix-tree.c b/lib/radix-tree.c index 9717e2a50374..8b1feca1230a 100644 --- a/lib/radix-tree.c +++ b/lib/radix-tree.c @@ -463,7 +463,7 @@ radix_tree_node_free(struct radix_tree_node *node) * To make use of this facility, the radix tree must be initialised without * __GFP_DIRECT_RECLAIM being passed to INIT_RADIX_TREE(). */ -static int __radix_tree_preload(gfp_t gfp_mask, unsigned nr) +static __must_check int __radix_tree_preload(gfp_t gfp_mask, unsigned nr) { struct radix_tree_preload *rtp; struct radix_tree_node *node; @@ -2104,7 +2104,8 @@ EXPORT_SYMBOL(radix_tree_tagged); */ void idr_preload(gfp_t gfp_mask) { - __radix_tree_preload(gfp_mask, IDR_PRELOAD_SIZE); + if (__radix_tree_preload(gfp_mask, IDR_PRELOAD_SIZE)) + preempt_disable(); } EXPORT_SYMBOL(idr_preload); @@ -2118,13 +2119,13 @@ EXPORT_SYMBOL(idr_preload); */ int ida_pre_get(struct ida *ida, gfp_t gfp) { - __radix_tree_preload(gfp, IDA_PRELOAD_SIZE); /* * The IDA API has no preload_end() equivalent. Instead, * ida_get_new() can return -EAGAIN, prompting the caller * to return to the ida_pre_get() step. */ - preempt_enable(); + if (!__radix_tree_preload(gfp, IDA_PRELOAD_SIZE)) + preempt_enable(); if (!this_cpu_read(ida_bitmap)) { struct ida_bitmap *bitmap = kmalloc(sizeof(*bitmap), gfp); -- cgit v1.2.3 From afdb05e9d61905220f09268535235288e6ba3a16 Mon Sep 17 00:00:00 2001 From: Takashi Iwai Date: Fri, 8 Sep 2017 16:15:58 -0700 Subject: lib/oid_registry.c: X.509: fix the buffer overflow in the utility function for OID string The sprint_oid() utility function doesn't properly check the buffer size that it causes that the warning in vsnprintf() be triggered. For example on v4.1 kernel: ------------[ cut here ]------------ WARNING: CPU: 0 PID: 2357 at lib/vsprintf.c:1867 vsnprintf+0x5a7/0x5c0() ... We can trigger this issue by injecting maliciously crafted x509 cert in DER format. Just using hex editor to change the length of OID to over the length of the SEQUENCE container. For example: 0:d=0 hl=4 l= 980 cons: SEQUENCE 4:d=1 hl=4 l= 700 cons: SEQUENCE 8:d=2 hl=2 l= 3 cons: cont [ 0 ] 10:d=3 hl=2 l= 1 prim: INTEGER :02 13:d=2 hl=2 l= 9 prim: INTEGER :9B47FAF791E7D1E3 24:d=2 hl=2 l= 13 cons: SEQUENCE 26:d=3 hl=2 l= 9 prim: OBJECT :sha256WithRSAEncryption 37:d=3 hl=2 l= 0 prim: NULL 39:d=2 hl=2 l= 121 cons: SEQUENCE 41:d=3 hl=2 l= 22 cons: SET 43:d=4 hl=2 l= 20 cons: SEQUENCE <=== the SEQ length is 20 45:d=5 hl=2 l= 3 prim: OBJECT :organizationName <=== the original length is 3, change the length of OID to over the length of SEQUENCE Pawel Wieczorkiewicz reported this problem and Takashi Iwai provided patch to fix it by checking the bufsize in sprint_oid(). Link: http://lkml.kernel.org/r/20170903021646.2080-1-jlee@suse.com Signed-off-by: Takashi Iwai Signed-off-by: "Lee, Chun-Yi" Reported-by: Pawel Wieczorkiewicz Cc: David Howells Cc: Rusty Russell Cc: Pawel Wieczorkiewicz Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/oid_registry.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/oid_registry.c b/lib/oid_registry.c index 318f382a010d..41b9e50711a7 100644 --- a/lib/oid_registry.c +++ b/lib/oid_registry.c @@ -142,9 +142,9 @@ int sprint_oid(const void *data, size_t datasize, char *buffer, size_t bufsize) } ret += count = snprintf(buffer, bufsize, ".%lu", num); buffer += count; - bufsize -= count; - if (bufsize == 0) + if (bufsize <= count) return -ENOBUFS; + bufsize -= count; } return ret; -- cgit v1.2.3 From f520409cfd3844660cedef8dc560ed0443e7f192 Mon Sep 17 00:00:00 2001 From: Dan Carpenter Date: Fri, 8 Sep 2017 16:16:53 -0700 Subject: test_kmod: remove paranoid UINT_MAX check on uint range processing The UINT_MAX comparison is not needed because "max" is already an unsigned int, and we expect developer C code max value input to have a sensible 0 - UINT_MAX range. Note that if it so happens to be UINT_MAX + 1 it would lead to an issue, but we expect the developer to know this. [mcgrof@kernel.org: massaged commit log] Link: http://lkml.kernel.org/r/20170802211707.28020-2-mcgrof@kernel.org Signed-off-by: Dan Carpenter Signed-off-by: Luis R. Rodriguez Cc: Dmitry Torokhov Cc: Kees Cook Cc: Jessica Yu Cc: Rusty Russell Cc: Michal Marek Cc: Petr Mladek Cc: Miroslav Benes Cc: Josh Poimboeuf Cc: Eric W. Biederman Cc: Shuah Khan Cc: Colin Ian King Cc: David Binderman Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/test_kmod.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/test_kmod.c b/lib/test_kmod.c index ff9148969b92..67fc7b9f41e3 100644 --- a/lib/test_kmod.c +++ b/lib/test_kmod.c @@ -924,7 +924,7 @@ static int test_dev_config_update_uint_range(struct kmod_test_device *test_dev, if (ret) return ret; - if (new < min || new > max || new > UINT_MAX) + if (new < min || new > max) return -EINVAL; mutex_lock(&test_dev->config_mutex); -- cgit v1.2.3 From 52c270d358b636d332114ac31e572970c5d0c4df Mon Sep 17 00:00:00 2001 From: Dan Carpenter Date: Fri, 8 Sep 2017 16:16:57 -0700 Subject: test_kmod: flip INT checks to be consistent Most checks will check for min and then max, except the int check. Flip the checks to be consistent with the other code. [mcgrof@kernel.org: massaged commit log] Link: http://lkml.kernel.org/r/20170802211707.28020-3-mcgrof@kernel.org Signed-off-by: Dan Carpenter Signed-off-by: Luis R. Rodriguez Cc: Dmitry Torokhov Cc: Kees Cook Cc: Jessica Yu Cc: Rusty Russell Cc: Michal Marek Cc: Petr Mladek Cc: Miroslav Benes Cc: Josh Poimboeuf Cc: Eric W. Biederman Cc: Shuah Khan Cc: Colin Ian King Cc: David Binderman Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/test_kmod.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/test_kmod.c b/lib/test_kmod.c index 67fc7b9f41e3..fba78d25e825 100644 --- a/lib/test_kmod.c +++ b/lib/test_kmod.c @@ -946,7 +946,7 @@ static int test_dev_config_update_int(struct kmod_test_device *test_dev, if (ret) return ret; - if (new > INT_MAX || new < INT_MIN) + if (new < INT_MIN || new > INT_MAX) return -EINVAL; mutex_lock(&test_dev->config_mutex); -- cgit v1.2.3 From f22ef333c32cc683922d7e3361a83ebc31b2ac6d Mon Sep 17 00:00:00 2001 From: Alexey Dobriyan Date: Fri, 8 Sep 2017 16:17:15 -0700 Subject: cpumask: make cpumask_next() out-of-line Every for_each_XXX_cpu() invocation calls cpumask_next() which is an inline function: static inline unsigned int cpumask_next(int n, const struct cpumask *srcp) { /* -1 is a legal arg here. */ if (n != -1) cpumask_check(n); return find_next_bit(cpumask_bits(srcp), nr_cpumask_bits, n + 1); } However! find_next_bit() is regular out-of-line function which means "nr_cpu_ids" load and increment happen at the caller resulting in a lot of bloat x86_64 defconfig: add/remove: 3/0 grow/shrink: 8/373 up/down: 155/-5668 (-5513) x86_64 allyesconfig-ish: add/remove: 3/1 grow/shrink: 57/634 up/down: 3515/-28177 (-24662) !!! Some archs redefine find_next_bit() but it is OK: m68k inline but SMP is not supported arm out-of-line unicore32 out-of-line Function call will happen anyway, so move load and increment into callee. Link: http://lkml.kernel.org/r/20170824230010.GA1593@avx2 Signed-off-by: Alexey Dobriyan Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- include/linux/cpumask.h | 15 +-------------- lib/cpumask.c | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 14 deletions(-) (limited to 'lib') diff --git a/include/linux/cpumask.h b/include/linux/cpumask.h index 68c5a8290275..cd415b733c2a 100644 --- a/include/linux/cpumask.h +++ b/include/linux/cpumask.h @@ -178,20 +178,7 @@ static inline unsigned int cpumask_first(const struct cpumask *srcp) return find_first_bit(cpumask_bits(srcp), nr_cpumask_bits); } -/** - * cpumask_next - get the next cpu in a cpumask - * @n: the cpu prior to the place to search (ie. return will be > @n) - * @srcp: the cpumask pointer - * - * Returns >= nr_cpu_ids if no further cpus set. - */ -static inline unsigned int cpumask_next(int n, const struct cpumask *srcp) -{ - /* -1 is a legal arg here. */ - if (n != -1) - cpumask_check(n); - return find_next_bit(cpumask_bits(srcp), nr_cpumask_bits, n+1); -} +unsigned int cpumask_next(int n, const struct cpumask *srcp); /** * cpumask_next_zero - get the next unset cpu in a cpumask diff --git a/lib/cpumask.c b/lib/cpumask.c index 4731a0895760..8b1a1bd77539 100644 --- a/lib/cpumask.c +++ b/lib/cpumask.c @@ -5,6 +5,22 @@ #include #include +/** + * cpumask_next - get the next cpu in a cpumask + * @n: the cpu prior to the place to search (ie. return will be > @n) + * @srcp: the cpumask pointer + * + * Returns >= nr_cpu_ids if no further cpus set. + */ +unsigned int cpumask_next(int n, const struct cpumask *srcp) +{ + /* -1 is a legal arg here. */ + if (n != -1) + cpumask_check(n); + return find_next_bit(cpumask_bits(srcp), nr_cpumask_bits, n + 1); +} +EXPORT_SYMBOL(cpumask_next); + /** * cpumask_next_and - get the next cpu in *src1p & *src2p * @n: the cpu prior to the place to search (ie. return will be > @n) -- cgit v1.2.3 From a47f68d6a944113bdc8097db6f933c2e17c27bf9 Mon Sep 17 00:00:00 2001 From: Eric Biggers Date: Wed, 13 Sep 2017 16:28:11 -0700 Subject: idr: remove WARN_ON_ONCE() when trying to replace negative ID IDR only supports non-negative IDs. There used to be a 'WARN_ON_ONCE(id < 0)' in idr_replace(), but it was intentionally removed by commit 2e1c9b286765 ("idr: remove WARN_ON_ONCE() on negative IDs"). Then it was added back by commit 0a835c4f090a ("Reimplement IDR and IDA using the radix tree"). However it seems that adding it back was a mistake, given that some users such as drm_gem_handle_delete() (DRM_IOCTL_GEM_CLOSE) pass in a value from userspace to idr_replace(), allowing the WARN_ON_ONCE to be triggered. drm_gem_handle_delete() actually just wants idr_replace() to return an error code if the ID is not allocated, including in the case where the ID is invalid (negative). So once again remove the bogus WARN_ON_ONCE(). This bug was found by syzkaller, which encountered the following warning: WARNING: CPU: 3 PID: 3008 at lib/idr.c:157 idr_replace+0x1d8/0x240 lib/idr.c:157 Kernel panic - not syncing: panic_on_warn set ... CPU: 3 PID: 3008 Comm: syzkaller218828 Not tainted 4.13.0-rc4-next-20170811 #2 Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIOS Bochs 01/01/2011 Call Trace: fixup_bug+0x40/0x90 arch/x86/kernel/traps.c:190 do_trap_no_signal arch/x86/kernel/traps.c:224 [inline] do_trap+0x260/0x390 arch/x86/kernel/traps.c:273 do_error_trap+0x120/0x390 arch/x86/kernel/traps.c:310 do_invalid_op+0x1b/0x20 arch/x86/kernel/traps.c:323 invalid_op+0x1e/0x30 arch/x86/entry/entry_64.S:930 RIP: 0010:idr_replace+0x1d8/0x240 lib/idr.c:157 RSP: 0018:ffff8800394bf9f8 EFLAGS: 00010297 RAX: ffff88003c6c60c0 RBX: 1ffff10007297f43 RCX: 0000000000000000 RDX: 0000000000000000 RSI: 0000000000000000 RDI: ffff8800394bfa78 RBP: ffff8800394bfae0 R08: ffffffff82856487 R09: 0000000000000000 R10: ffff8800394bf9a8 R11: ffff88006c8bae28 R12: ffffffffffffffff R13: ffff8800394bfab8 R14: dffffc0000000000 R15: ffff8800394bfbc8 drm_gem_handle_delete+0x33/0xa0 drivers/gpu/drm/drm_gem.c:297 drm_gem_close_ioctl+0xa1/0xe0 drivers/gpu/drm/drm_gem.c:671 drm_ioctl_kernel+0x1e7/0x2e0 drivers/gpu/drm/drm_ioctl.c:729 drm_ioctl+0x72e/0xa50 drivers/gpu/drm/drm_ioctl.c:825 vfs_ioctl fs/ioctl.c:45 [inline] do_vfs_ioctl+0x1b1/0x1520 fs/ioctl.c:685 SYSC_ioctl fs/ioctl.c:700 [inline] SyS_ioctl+0x8f/0xc0 fs/ioctl.c:691 entry_SYSCALL_64_fastpath+0x1f/0xbe Here is a C reproducer: #include #include #include #include #include int main(void) { int cardfd = open("/dev/dri/card0", O_RDONLY); ioctl(cardfd, DRM_IOCTL_GEM_CLOSE, &(struct drm_gem_close) { .handle = -1 } ); } Link: http://lkml.kernel.org/r/20170906235306.20534-1-ebiggers3@gmail.com Fixes: 0a835c4f090a ("Reimplement IDR and IDA using the radix tree") Signed-off-by: Eric Biggers Acked-by: Tejun Heo Cc: Dmitry Vyukov Cc: Matthew Wilcox Cc: [v4.11+] Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/idr.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/idr.c b/lib/idr.c index 082778cf883e..f9adf4805fd7 100644 --- a/lib/idr.c +++ b/lib/idr.c @@ -151,7 +151,7 @@ EXPORT_SYMBOL(idr_get_next_ext); */ void *idr_replace(struct idr *idr, void *ptr, int id) { - if (WARN_ON_ONCE(id < 0)) + if (id < 0) return ERR_PTR(-EINVAL); return idr_replace_ext(idr, ptr, id); -- cgit v1.2.3 From 8185f5708d2488393dc2acec21d19ac53d61b4a0 Mon Sep 17 00:00:00 2001 From: Geert Uytterhoeven Date: Wed, 13 Sep 2017 16:28:20 -0700 Subject: lib/test_bitmap.c: use ULL suffix for 64-bit constants With gcc 4.1.2: lib/test_bitmap.c:189: warning: integer constant is too large for `long' type lib/test_bitmap.c:190: warning: integer constant is too large for `long' type lib/test_bitmap.c:194: warning: integer constant is too large for `long' type lib/test_bitmap.c:195: warning: integer constant is too large for `long' type Add the missing "ULL" suffix to fix this. Link: http://lkml.kernel.org/r/1505040523-31230-1-git-send-email-geert@linux-m68k.org Fixes: 60ef690018b262dd ("bitmap: introduce BITMAP_FROM_U64()") Signed-off-by: Geert Uytterhoeven Acked-by: Yury Norov Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- lib/test_bitmap.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/lib/test_bitmap.c b/lib/test_bitmap.c index 599c6713f2a2..aa1f2669bdd5 100644 --- a/lib/test_bitmap.c +++ b/lib/test_bitmap.c @@ -186,13 +186,13 @@ static const unsigned long exp[] __initconst = { BITMAP_FROM_U64(0x22222222), BITMAP_FROM_U64(0xffffffff), BITMAP_FROM_U64(0xfffffffe), - BITMAP_FROM_U64(0x3333333311111111), - BITMAP_FROM_U64(0xffffffff77777777) + BITMAP_FROM_U64(0x3333333311111111ULL), + BITMAP_FROM_U64(0xffffffff77777777ULL) }; static const unsigned long exp2[] __initconst = { - BITMAP_FROM_U64(0x3333333311111111), - BITMAP_FROM_U64(0xffffffff77777777) + BITMAP_FROM_U64(0x3333333311111111ULL), + BITMAP_FROM_U64(0xffffffff77777777ULL) }; static const struct test_bitmap_parselist parselist_tests[] __initconst = { -- cgit v1.2.3 From 0ee931c4e31a5efb134c76440405e9219f896e33 Mon Sep 17 00:00:00 2001 From: Michal Hocko Date: Wed, 13 Sep 2017 16:28:29 -0700 Subject: mm: treewide: remove GFP_TEMPORARY allocation flag GFP_TEMPORARY was introduced by commit e12ba74d8ff3 ("Group short-lived and reclaimable kernel allocations") along with __GFP_RECLAIMABLE. It's primary motivation was to allow users to tell that an allocation is short lived and so the allocator can try to place such allocations close together and prevent long term fragmentation. As much as this sounds like a reasonable semantic it becomes much less clear when to use the highlevel GFP_TEMPORARY allocation flag. How long is temporary? Can the context holding that memory sleep? Can it take locks? It seems there is no good answer for those questions. The current implementation of GFP_TEMPORARY is basically GFP_KERNEL | __GFP_RECLAIMABLE which in itself is tricky because basically none of the existing caller provide a way to reclaim the allocated memory. So this is rather misleading and hard to evaluate for any benefits. I have checked some random users and none of them has added the flag with a specific justification. I suspect most of them just copied from other existing users and others just thought it might be a good idea to use without any measuring. This suggests that GFP_TEMPORARY just motivates for cargo cult usage without any reasoning. I believe that our gfp flags are quite complex already and especially those with highlevel semantic should be clearly defined to prevent from confusion and abuse. Therefore I propose dropping GFP_TEMPORARY and replace all existing users to simply use GFP_KERNEL. Please note that SLAB users with shrinkers will still get __GFP_RECLAIMABLE heuristic and so they will be placed properly for memory fragmentation prevention. I can see reasons we might want some gfp flag to reflect shorterm allocations but I propose starting from a clear semantic definition and only then add users with proper justification. This was been brought up before LSF this year by Matthew [1] and it turned out that GFP_TEMPORARY really doesn't have a clear semantic. It seems to be a heuristic without any measured advantage for most (if not all) its current users. The follow up discussion has revealed that opinions on what might be temporary allocation differ a lot between developers. So rather than trying to tweak existing users into a semantic which they haven't expected I propose to simply remove the flag and start from scratch if we really need a semantic for short term allocations. [1] http://lkml.kernel.org/r/20170118054945.GD18349@bombadil.infradead.org [akpm@linux-foundation.org: fix typo] [akpm@linux-foundation.org: coding-style fixes] [sfr@canb.auug.org.au: drm/i915: fix up] Link: http://lkml.kernel.org/r/20170816144703.378d4f4d@canb.auug.org.au Link: http://lkml.kernel.org/r/20170728091904.14627-1-mhocko@kernel.org Signed-off-by: Michal Hocko Signed-off-by: Stephen Rothwell Acked-by: Mel Gorman Acked-by: Vlastimil Babka Cc: Matthew Wilcox Cc: Neil Brown Cc: "Theodore Ts'o" Signed-off-by: Andrew Morton Signed-off-by: Linus Torvalds --- arch/arc/kernel/setup.c | 2 +- arch/arc/kernel/troubleshoot.c | 2 +- arch/powerpc/kernel/rtas.c | 4 ++-- arch/powerpc/platforms/pseries/suspend.c | 2 +- drivers/gpu/drm/drm_blend.c | 2 +- drivers/gpu/drm/drm_dp_dual_mode_helper.c | 2 +- drivers/gpu/drm/drm_scdc_helper.c | 2 +- drivers/gpu/drm/etnaviv/etnaviv_gem_submit.c | 2 +- drivers/gpu/drm/i915/i915_gem.c | 2 +- drivers/gpu/drm/i915/i915_gem_execbuffer.c | 12 ++++++------ drivers/gpu/drm/i915/i915_gem_gtt.c | 2 +- drivers/gpu/drm/i915/i915_gem_userptr.c | 4 ++-- drivers/gpu/drm/i915/i915_gpu_error.c | 6 +++--- drivers/gpu/drm/i915/selftests/i915_random.c | 2 +- drivers/gpu/drm/i915/selftests/intel_breadcrumbs.c | 10 +++++----- drivers/gpu/drm/i915/selftests/intel_uncore.c | 2 +- drivers/gpu/drm/lib/drm_random.c | 2 +- drivers/gpu/drm/msm/msm_gem_submit.c | 2 +- drivers/gpu/drm/selftests/test-drm_mm.c | 4 ++-- drivers/misc/cxl/pci.c | 2 +- drivers/xen/gntalloc.c | 2 +- fs/coredump.c | 2 +- fs/exec.c | 4 ++-- fs/overlayfs/copy_up.c | 2 +- fs/overlayfs/dir.c | 2 +- fs/overlayfs/namei.c | 12 ++++++------ fs/proc/base.c | 8 ++++---- fs/proc/task_mmu.c | 2 +- include/linux/gfp.h | 2 -- include/trace/events/mmflags.h | 1 - kernel/locking/test-ww_mutex.c | 2 +- kernel/trace/trace_events_filter.c | 2 +- lib/string_helpers.c | 4 ++-- mm/shmem.c | 2 +- mm/slub.c | 2 +- tools/perf/builtin-kmem.c | 1 - 36 files changed, 57 insertions(+), 61 deletions(-) (limited to 'lib') diff --git a/arch/arc/kernel/setup.c b/arch/arc/kernel/setup.c index c4ffb441716c..877cec8f5ea2 100644 --- a/arch/arc/kernel/setup.c +++ b/arch/arc/kernel/setup.c @@ -510,7 +510,7 @@ static int show_cpuinfo(struct seq_file *m, void *v) goto done; } - str = (char *)__get_free_page(GFP_TEMPORARY); + str = (char *)__get_free_page(GFP_KERNEL); if (!str) goto done; diff --git a/arch/arc/kernel/troubleshoot.c b/arch/arc/kernel/troubleshoot.c index 7e94476f3994..7d8c1d6c2f60 100644 --- a/arch/arc/kernel/troubleshoot.c +++ b/arch/arc/kernel/troubleshoot.c @@ -178,7 +178,7 @@ void show_regs(struct pt_regs *regs) struct callee_regs *cregs; char *buf; - buf = (char *)__get_free_page(GFP_TEMPORARY); + buf = (char *)__get_free_page(GFP_KERNEL); if (!buf) return; diff --git a/arch/powerpc/kernel/rtas.c b/arch/powerpc/kernel/rtas.c index b8a4987f58cf..1643e9e53655 100644 --- a/arch/powerpc/kernel/rtas.c +++ b/arch/powerpc/kernel/rtas.c @@ -914,7 +914,7 @@ int rtas_online_cpus_mask(cpumask_var_t cpus) if (ret) { cpumask_var_t tmp_mask; - if (!alloc_cpumask_var(&tmp_mask, GFP_TEMPORARY)) + if (!alloc_cpumask_var(&tmp_mask, GFP_KERNEL)) return ret; /* Use tmp_mask to preserve cpus mask from first failure */ @@ -962,7 +962,7 @@ int rtas_ibm_suspend_me(u64 handle) return -EIO; } - if (!alloc_cpumask_var(&offline_mask, GFP_TEMPORARY)) + if (!alloc_cpumask_var(&offline_mask, GFP_KERNEL)) return -ENOMEM; atomic_set(&data.working, 0); diff --git a/arch/powerpc/platforms/pseries/suspend.c b/arch/powerpc/platforms/pseries/suspend.c index e76aefae2aa2..89726f07d249 100644 --- a/arch/powerpc/platforms/pseries/suspend.c +++ b/arch/powerpc/platforms/pseries/suspend.c @@ -151,7 +151,7 @@ static ssize_t store_hibernate(struct device *dev, if (!capable(CAP_SYS_ADMIN)) return -EPERM; - if (!alloc_cpumask_var(&offline_mask, GFP_TEMPORARY)) + if (!alloc_cpumask_var(&offline_mask, GFP_KERNEL)) return -ENOMEM; stream_id = simple_strtoul(buf, NULL, 16); diff --git a/drivers/gpu/drm/drm_blend.c b/drivers/gpu/drm/drm_blend.c index db6aeec50b82..2e5e089dd912 100644 --- a/drivers/gpu/drm/drm_blend.c +++ b/drivers/gpu/drm/drm_blend.c @@ -319,7 +319,7 @@ static int drm_atomic_helper_crtc_normalize_zpos(struct drm_crtc *crtc, DRM_DEBUG_ATOMIC("[CRTC:%d:%s] calculating normalized zpos values\n", crtc->base.id, crtc->name); - states = kmalloc_array(total_planes, sizeof(*states), GFP_TEMPORARY); + states = kmalloc_array(total_planes, sizeof(*states), GFP_KERNEL); if (!states) return -ENOMEM; diff --git a/drivers/gpu/drm/drm_dp_dual_mode_helper.c b/drivers/gpu/drm/drm_dp_dual_mode_helper.c index 80e62f669321..0ef9011a1856 100644 --- a/drivers/gpu/drm/drm_dp_dual_mode_helper.c +++ b/drivers/gpu/drm/drm_dp_dual_mode_helper.c @@ -111,7 +111,7 @@ ssize_t drm_dp_dual_mode_write(struct i2c_adapter *adapter, void *data; int ret; - data = kmalloc(msg.len, GFP_TEMPORARY); + data = kmalloc(msg.len, GFP_KERNEL); if (!data) return -ENOMEM; diff --git a/drivers/gpu/drm/drm_scdc_helper.c b/drivers/gpu/drm/drm_scdc_helper.c index 7d1b0f011d33..935653eb3616 100644 --- a/drivers/gpu/drm/drm_scdc_helper.c +++ b/drivers/gpu/drm/drm_scdc_helper.c @@ -102,7 +102,7 @@ ssize_t drm_scdc_write(struct i2c_adapter *adapter, u8 offset, void *data; int err; - data = kmalloc(1 + size, GFP_TEMPORARY); + data = kmalloc(1 + size, GFP_KERNEL); if (!data) return -ENOMEM; diff --git a/drivers/gpu/drm/etnaviv/etnaviv_gem_submit.c b/drivers/gpu/drm/etnaviv/etnaviv_gem_submit.c index a7ff2e4c00d2..026ef4e02f85 100644 --- a/drivers/gpu/drm/etnaviv/etnaviv_gem_submit.c +++ b/drivers/gpu/drm/etnaviv/etnaviv_gem_submit.c @@ -37,7 +37,7 @@ static struct etnaviv_gem_submit *submit_create(struct drm_device *dev, struct etnaviv_gem_submit *submit; size_t sz = size_vstruct(nr, sizeof(submit->bos[0]), sizeof(*submit)); - submit = kmalloc(sz, GFP_TEMPORARY | __GFP_NOWARN | __GFP_NORETRY); + submit = kmalloc(sz, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY); if (submit) { submit->dev = dev; submit->gpu = gpu; diff --git a/drivers/gpu/drm/i915/i915_gem.c b/drivers/gpu/drm/i915/i915_gem.c index 57317715977f..19404c96eeb1 100644 --- a/drivers/gpu/drm/i915/i915_gem.c +++ b/drivers/gpu/drm/i915/i915_gem.c @@ -2540,7 +2540,7 @@ static void *i915_gem_object_map(const struct drm_i915_gem_object *obj, if (n_pages > ARRAY_SIZE(stack_pages)) { /* Too big for stack -- allocate temporary array instead */ - pages = kvmalloc_array(n_pages, sizeof(*pages), GFP_TEMPORARY); + pages = kvmalloc_array(n_pages, sizeof(*pages), GFP_KERNEL); if (!pages) return NULL; } diff --git a/drivers/gpu/drm/i915/i915_gem_execbuffer.c b/drivers/gpu/drm/i915/i915_gem_execbuffer.c index 50d5e24f91a9..92437f455b43 100644 --- a/drivers/gpu/drm/i915/i915_gem_execbuffer.c +++ b/drivers/gpu/drm/i915/i915_gem_execbuffer.c @@ -293,7 +293,7 @@ static int eb_create(struct i915_execbuffer *eb) * as possible to perform the allocation and warn * if it fails. */ - flags = GFP_TEMPORARY; + flags = GFP_KERNEL; if (size > 1) flags |= __GFP_NORETRY | __GFP_NOWARN; @@ -1515,7 +1515,7 @@ static int eb_copy_relocations(const struct i915_execbuffer *eb) urelocs = u64_to_user_ptr(eb->exec[i].relocs_ptr); size = nreloc * sizeof(*relocs); - relocs = kvmalloc_array(size, 1, GFP_TEMPORARY); + relocs = kvmalloc_array(size, 1, GFP_KERNEL); if (!relocs) { kvfree(relocs); err = -ENOMEM; @@ -2077,7 +2077,7 @@ get_fence_array(struct drm_i915_gem_execbuffer2 *args, return ERR_PTR(-EFAULT); fences = kvmalloc_array(args->num_cliprects, sizeof(*fences), - __GFP_NOWARN | GFP_TEMPORARY); + __GFP_NOWARN | GFP_KERNEL); if (!fences) return ERR_PTR(-ENOMEM); @@ -2463,9 +2463,9 @@ i915_gem_execbuffer(struct drm_device *dev, void *data, /* Copy in the exec list from userland */ exec_list = kvmalloc_array(args->buffer_count, sizeof(*exec_list), - __GFP_NOWARN | GFP_TEMPORARY); + __GFP_NOWARN | GFP_KERNEL); exec2_list = kvmalloc_array(args->buffer_count + 1, sz, - __GFP_NOWARN | GFP_TEMPORARY); + __GFP_NOWARN | GFP_KERNEL); if (exec_list == NULL || exec2_list == NULL) { DRM_DEBUG("Failed to allocate exec list for %d buffers\n", args->buffer_count); @@ -2543,7 +2543,7 @@ i915_gem_execbuffer2(struct drm_device *dev, void *data, /* Allocate an extra slot for use by the command parser */ exec2_list = kvmalloc_array(args->buffer_count + 1, sz, - __GFP_NOWARN | GFP_TEMPORARY); + __GFP_NOWARN | GFP_KERNEL); if (exec2_list == NULL) { DRM_DEBUG("Failed to allocate exec list for %d buffers\n", args->buffer_count); diff --git a/drivers/gpu/drm/i915/i915_gem_gtt.c b/drivers/gpu/drm/i915/i915_gem_gtt.c index 0d5a988b3867..e2410eb5d96e 100644 --- a/drivers/gpu/drm/i915/i915_gem_gtt.c +++ b/drivers/gpu/drm/i915/i915_gem_gtt.c @@ -3231,7 +3231,7 @@ intel_rotate_pages(struct intel_rotation_info *rot_info, /* Allocate a temporary list of source pages for random access. */ page_addr_list = kvmalloc_array(n_pages, sizeof(dma_addr_t), - GFP_TEMPORARY); + GFP_KERNEL); if (!page_addr_list) return ERR_PTR(ret); diff --git a/drivers/gpu/drm/i915/i915_gem_userptr.c b/drivers/gpu/drm/i915/i915_gem_userptr.c index 23fd18bd1b56..709efe2357ea 100644 --- a/drivers/gpu/drm/i915/i915_gem_userptr.c +++ b/drivers/gpu/drm/i915/i915_gem_userptr.c @@ -507,7 +507,7 @@ __i915_gem_userptr_get_pages_worker(struct work_struct *_work) ret = -ENOMEM; pinned = 0; - pvec = kvmalloc_array(npages, sizeof(struct page *), GFP_TEMPORARY); + pvec = kvmalloc_array(npages, sizeof(struct page *), GFP_KERNEL); if (pvec != NULL) { struct mm_struct *mm = obj->userptr.mm->mm; unsigned int flags = 0; @@ -643,7 +643,7 @@ i915_gem_userptr_get_pages(struct drm_i915_gem_object *obj) if (mm == current->mm) { pvec = kvmalloc_array(num_pages, sizeof(struct page *), - GFP_TEMPORARY | + GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); if (pvec) /* defer to worker if malloc fails */ diff --git a/drivers/gpu/drm/i915/i915_gpu_error.c b/drivers/gpu/drm/i915/i915_gpu_error.c index ed5a1eb839ad..0c779671fe2d 100644 --- a/drivers/gpu/drm/i915/i915_gpu_error.c +++ b/drivers/gpu/drm/i915/i915_gpu_error.c @@ -787,16 +787,16 @@ int i915_error_state_buf_init(struct drm_i915_error_state_buf *ebuf, */ ebuf->size = count + 1 > PAGE_SIZE ? count + 1 : PAGE_SIZE; ebuf->buf = kmalloc(ebuf->size, - GFP_TEMPORARY | __GFP_NORETRY | __GFP_NOWARN); + GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); if (ebuf->buf == NULL) { ebuf->size = PAGE_SIZE; - ebuf->buf = kmalloc(ebuf->size, GFP_TEMPORARY); + ebuf->buf = kmalloc(ebuf->size, GFP_KERNEL); } if (ebuf->buf == NULL) { ebuf->size = 128; - ebuf->buf = kmalloc(ebuf->size, GFP_TEMPORARY); + ebuf->buf = kmalloc(ebuf->size, GFP_KERNEL); } if (ebuf->buf == NULL) diff --git a/drivers/gpu/drm/i915/selftests/i915_random.c b/drivers/gpu/drm/i915/selftests/i915_random.c index d044bf9a6feb..222c511bea49 100644 --- a/drivers/gpu/drm/i915/selftests/i915_random.c +++ b/drivers/gpu/drm/i915/selftests/i915_random.c @@ -62,7 +62,7 @@ unsigned int *i915_random_order(unsigned int count, struct rnd_state *state) { unsigned int *order, i; - order = kmalloc_array(count, sizeof(*order), GFP_TEMPORARY); + order = kmalloc_array(count, sizeof(*order), GFP_KERNEL); if (!order) return order; diff --git a/drivers/gpu/drm/i915/selftests/intel_breadcrumbs.c b/drivers/gpu/drm/i915/selftests/intel_breadcrumbs.c index 7276194c04f7..828904b7d468 100644 --- a/drivers/gpu/drm/i915/selftests/intel_breadcrumbs.c +++ b/drivers/gpu/drm/i915/selftests/intel_breadcrumbs.c @@ -117,12 +117,12 @@ static int igt_random_insert_remove(void *arg) mock_engine_reset(engine); - waiters = kvmalloc_array(count, sizeof(*waiters), GFP_TEMPORARY); + waiters = kvmalloc_array(count, sizeof(*waiters), GFP_KERNEL); if (!waiters) goto out_engines; bitmap = kcalloc(DIV_ROUND_UP(count, BITS_PER_LONG), sizeof(*bitmap), - GFP_TEMPORARY); + GFP_KERNEL); if (!bitmap) goto out_waiters; @@ -187,12 +187,12 @@ static int igt_insert_complete(void *arg) mock_engine_reset(engine); - waiters = kvmalloc_array(count, sizeof(*waiters), GFP_TEMPORARY); + waiters = kvmalloc_array(count, sizeof(*waiters), GFP_KERNEL); if (!waiters) goto out_engines; bitmap = kcalloc(DIV_ROUND_UP(count, BITS_PER_LONG), sizeof(*bitmap), - GFP_TEMPORARY); + GFP_KERNEL); if (!bitmap) goto out_waiters; @@ -368,7 +368,7 @@ static int igt_wakeup(void *arg) mock_engine_reset(engine); - waiters = kvmalloc_array(count, sizeof(*waiters), GFP_TEMPORARY); + waiters = kvmalloc_array(count, sizeof(*waiters), GFP_KERNEL); if (!waiters) goto out_engines; diff --git a/drivers/gpu/drm/i915/selftests/intel_uncore.c b/drivers/gpu/drm/i915/selftests/intel_uncore.c index 2d0fef2cfca6..3cac22eb47ce 100644 --- a/drivers/gpu/drm/i915/selftests/intel_uncore.c +++ b/drivers/gpu/drm/i915/selftests/intel_uncore.c @@ -127,7 +127,7 @@ static int intel_uncore_check_forcewake_domains(struct drm_i915_private *dev_pri return 0; valid = kzalloc(BITS_TO_LONGS(FW_RANGE) * sizeof(*valid), - GFP_TEMPORARY); + GFP_KERNEL); if (!valid) return -ENOMEM; diff --git a/drivers/gpu/drm/lib/drm_random.c b/drivers/gpu/drm/lib/drm_random.c index 7b12a68c3b54..a78c4b483e8d 100644 --- a/drivers/gpu/drm/lib/drm_random.c +++ b/drivers/gpu/drm/lib/drm_random.c @@ -28,7 +28,7 @@ unsigned int *drm_random_order(unsigned int count, struct rnd_state *state) { unsigned int *order, i; - order = kmalloc_array(count, sizeof(*order), GFP_TEMPORARY); + order = kmalloc_array(count, sizeof(*order), GFP_KERNEL); if (!order) return order; diff --git a/drivers/gpu/drm/msm/msm_gem_submit.c b/drivers/gpu/drm/msm/msm_gem_submit.c index 8a75c0bd8a78..5d0a75d4b249 100644 --- a/drivers/gpu/drm/msm/msm_gem_submit.c +++ b/drivers/gpu/drm/msm/msm_gem_submit.c @@ -40,7 +40,7 @@ static struct msm_gem_submit *submit_create(struct drm_device *dev, if (sz > SIZE_MAX) return NULL; - submit = kmalloc(sz, GFP_TEMPORARY | __GFP_NOWARN | __GFP_NORETRY); + submit = kmalloc(sz, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY); if (!submit) return NULL; diff --git a/drivers/gpu/drm/selftests/test-drm_mm.c b/drivers/gpu/drm/selftests/test-drm_mm.c index dfdd858eda0a..86eb4c185a28 100644 --- a/drivers/gpu/drm/selftests/test-drm_mm.c +++ b/drivers/gpu/drm/selftests/test-drm_mm.c @@ -1627,7 +1627,7 @@ static int igt_topdown(void *ignored) goto err; bitmap = kzalloc(count / BITS_PER_LONG * sizeof(unsigned long), - GFP_TEMPORARY); + GFP_KERNEL); if (!bitmap) goto err_nodes; @@ -1741,7 +1741,7 @@ static int igt_bottomup(void *ignored) goto err; bitmap = kzalloc(count / BITS_PER_LONG * sizeof(unsigned long), - GFP_TEMPORARY); + GFP_KERNEL); if (!bitmap) goto err_nodes; diff --git a/drivers/misc/cxl/pci.c b/drivers/misc/cxl/pci.c index d18b3d9292fd..3ba04f371380 100644 --- a/drivers/misc/cxl/pci.c +++ b/drivers/misc/cxl/pci.c @@ -1279,7 +1279,7 @@ ssize_t cxl_pci_afu_read_err_buffer(struct cxl_afu *afu, char *buf, } /* use bounce buffer for copy */ - tbuf = (void *)__get_free_page(GFP_TEMPORARY); + tbuf = (void *)__get_free_page(GFP_KERNEL); if (!tbuf) return -ENOMEM; diff --git a/drivers/xen/gntalloc.c b/drivers/xen/gntalloc.c index 1bf55a32a4b3..3fa40c723e8e 100644 --- a/drivers/xen/gntalloc.c +++ b/drivers/xen/gntalloc.c @@ -294,7 +294,7 @@ static long gntalloc_ioctl_alloc(struct gntalloc_file_private_data *priv, goto out; } - gref_ids = kcalloc(op.count, sizeof(gref_ids[0]), GFP_TEMPORARY); + gref_ids = kcalloc(op.count, sizeof(gref_ids[0]), GFP_KERNEL); if (!gref_ids) { rc = -ENOMEM; goto out; diff --git a/fs/coredump.c b/fs/coredump.c index 592683711c64..0eec03696707 100644 --- a/fs/coredump.c +++ b/fs/coredump.c @@ -161,7 +161,7 @@ static int cn_print_exe_file(struct core_name *cn) if (!exe_file) return cn_esc_printf(cn, "%s (path unknown)", current->comm); - pathbuf = kmalloc(PATH_MAX, GFP_TEMPORARY); + pathbuf = kmalloc(PATH_MAX, GFP_KERNEL); if (!pathbuf) { ret = -ENOMEM; goto put_exe_file; diff --git a/fs/exec.c b/fs/exec.c index 01a9fb9d8ac3..daa19d85c066 100644 --- a/fs/exec.c +++ b/fs/exec.c @@ -1763,9 +1763,9 @@ static int do_execveat_common(int fd, struct filename *filename, bprm->filename = filename->name; } else { if (filename->name[0] == '\0') - pathbuf = kasprintf(GFP_TEMPORARY, "/dev/fd/%d", fd); + pathbuf = kasprintf(GFP_KERNEL, "/dev/fd/%d", fd); else - pathbuf = kasprintf(GFP_TEMPORARY, "/dev/fd/%d/%s", + pathbuf = kasprintf(GFP_KERNEL, "/dev/fd/%d/%s", fd, filename->name); if (!pathbuf) { retval = -ENOMEM; diff --git a/fs/overlayfs/copy_up.c b/fs/overlayfs/copy_up.c index acb6f97deb97..aad97b30d5e6 100644 --- a/fs/overlayfs/copy_up.c +++ b/fs/overlayfs/copy_up.c @@ -241,7 +241,7 @@ struct ovl_fh *ovl_encode_fh(struct dentry *lower, bool is_upper) int buflen = MAX_HANDLE_SZ; uuid_t *uuid = &lower->d_sb->s_uuid; - buf = kmalloc(buflen, GFP_TEMPORARY); + buf = kmalloc(buflen, GFP_KERNEL); if (!buf) return ERR_PTR(-ENOMEM); diff --git a/fs/overlayfs/dir.c b/fs/overlayfs/dir.c index 9cb0c80e5967..3309b1912241 100644 --- a/fs/overlayfs/dir.c +++ b/fs/overlayfs/dir.c @@ -833,7 +833,7 @@ static char *ovl_get_redirect(struct dentry *dentry, bool samedir) goto out; } - buf = ret = kmalloc(buflen, GFP_TEMPORARY); + buf = ret = kmalloc(buflen, GFP_KERNEL); if (!buf) goto out; diff --git a/fs/overlayfs/namei.c b/fs/overlayfs/namei.c index 8aef2b304b2d..c3addd1114f1 100644 --- a/fs/overlayfs/namei.c +++ b/fs/overlayfs/namei.c @@ -38,7 +38,7 @@ static int ovl_check_redirect(struct dentry *dentry, struct ovl_lookup_data *d, return 0; goto fail; } - buf = kzalloc(prelen + res + strlen(post) + 1, GFP_TEMPORARY); + buf = kzalloc(prelen + res + strlen(post) + 1, GFP_KERNEL); if (!buf) return -ENOMEM; @@ -103,7 +103,7 @@ static struct ovl_fh *ovl_get_origin_fh(struct dentry *dentry) if (res == 0) return NULL; - fh = kzalloc(res, GFP_TEMPORARY); + fh = kzalloc(res, GFP_KERNEL); if (!fh) return ERR_PTR(-ENOMEM); @@ -309,7 +309,7 @@ static int ovl_check_origin(struct dentry *upperdentry, BUG_ON(*ctrp); if (!*stackp) - *stackp = kmalloc(sizeof(struct path), GFP_TEMPORARY); + *stackp = kmalloc(sizeof(struct path), GFP_KERNEL); if (!*stackp) { dput(origin); return -ENOMEM; @@ -418,7 +418,7 @@ int ovl_verify_index(struct dentry *index, struct path *lowerstack, err = -ENOMEM; len = index->d_name.len / 2; - fh = kzalloc(len, GFP_TEMPORARY); + fh = kzalloc(len, GFP_KERNEL); if (!fh) goto fail; @@ -478,7 +478,7 @@ int ovl_get_index_name(struct dentry *origin, struct qstr *name) return PTR_ERR(fh); err = -ENOMEM; - n = kzalloc(fh->len * 2, GFP_TEMPORARY); + n = kzalloc(fh->len * 2, GFP_KERNEL); if (n) { s = bin2hex(n, fh, fh->len); *name = (struct qstr) QSTR_INIT(n, s - n); @@ -646,7 +646,7 @@ struct dentry *ovl_lookup(struct inode *dir, struct dentry *dentry, if (!d.stop && poe->numlower) { err = -ENOMEM; stack = kcalloc(ofs->numlower, sizeof(struct path), - GFP_TEMPORARY); + GFP_KERNEL); if (!stack) goto out_put_upper; } diff --git a/fs/proc/base.c b/fs/proc/base.c index e5d89a0d0b8a..ad3b0762cc3e 100644 --- a/fs/proc/base.c +++ b/fs/proc/base.c @@ -232,7 +232,7 @@ static ssize_t proc_pid_cmdline_read(struct file *file, char __user *buf, goto out_mmput; } - page = (char *)__get_free_page(GFP_TEMPORARY); + page = (char *)__get_free_page(GFP_KERNEL); if (!page) { rv = -ENOMEM; goto out_mmput; @@ -813,7 +813,7 @@ static ssize_t mem_rw(struct file *file, char __user *buf, if (!mm) return 0; - page = (char *)__get_free_page(GFP_TEMPORARY); + page = (char *)__get_free_page(GFP_KERNEL); if (!page) return -ENOMEM; @@ -918,7 +918,7 @@ static ssize_t environ_read(struct file *file, char __user *buf, if (!mm || !mm->env_end) return 0; - page = (char *)__get_free_page(GFP_TEMPORARY); + page = (char *)__get_free_page(GFP_KERNEL); if (!page) return -ENOMEM; @@ -1630,7 +1630,7 @@ out: static int do_proc_readlink(struct path *path, char __user *buffer, int buflen) { - char *tmp = (char*)__get_free_page(GFP_TEMPORARY); + char *tmp = (char *)__get_free_page(GFP_KERNEL); char *pathname; int len; diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c index 7b40e11ede9b..5589b4bd4b85 100644 --- a/fs/proc/task_mmu.c +++ b/fs/proc/task_mmu.c @@ -1474,7 +1474,7 @@ static ssize_t pagemap_read(struct file *file, char __user *buf, pm.show_pfn = file_ns_capable(file, &init_user_ns, CAP_SYS_ADMIN); pm.len = (PAGEMAP_WALK_SIZE >> PAGE_SHIFT); - pm.buffer = kmalloc(pm.len * PM_ENTRY_BYTES, GFP_TEMPORARY); + pm.buffer = kmalloc(pm.len * PM_ENTRY_BYTES, GFP_KERNEL); ret = -ENOMEM; if (!pm.buffer) goto out_mm; diff --git a/include/linux/gfp.h b/include/linux/gfp.h index bcfb9f7c46f5..f780718b7391 100644 --- a/include/linux/gfp.h +++ b/include/linux/gfp.h @@ -288,8 +288,6 @@ struct vm_area_struct; #define GFP_NOWAIT (__GFP_KSWAPD_RECLAIM) #define GFP_NOIO (__GFP_RECLAIM) #define GFP_NOFS (__GFP_RECLAIM | __GFP_IO) -#define GFP_TEMPORARY (__GFP_RECLAIM | __GFP_IO | __GFP_FS | \ - __GFP_RECLAIMABLE) #define GFP_USER (__GFP_RECLAIM | __GFP_IO | __GFP_FS | __GFP_HARDWALL) #define GFP_DMA __GFP_DMA #define GFP_DMA32 __GFP_DMA32 diff --git a/include/trace/events/mmflags.h b/include/trace/events/mmflags.h index 4c2e4737d7bc..fec6291a6703 100644 --- a/include/trace/events/mmflags.h +++ b/include/trace/events/mmflags.h @@ -18,7 +18,6 @@ {(unsigned long)GFP_HIGHUSER_MOVABLE, "GFP_HIGHUSER_MOVABLE"},\ {(unsigned long)GFP_HIGHUSER, "GFP_HIGHUSER"}, \ {(unsigned long)GFP_USER, "GFP_USER"}, \ - {(unsigned long)GFP_TEMPORARY, "GFP_TEMPORARY"}, \ {(unsigned long)GFP_KERNEL_ACCOUNT, "GFP_KERNEL_ACCOUNT"}, \ {(unsigned long)GFP_KERNEL, "GFP_KERNEL"}, \ {(unsigned long)GFP_NOFS, "GFP_NOFS"}, \ diff --git a/kernel/locking/test-ww_mutex.c b/kernel/locking/test-ww_mutex.c index 39f56c870051..0e4cd64ad2c0 100644 --- a/kernel/locking/test-ww_mutex.c +++ b/kernel/locking/test-ww_mutex.c @@ -362,7 +362,7 @@ static int *get_random_order(int count) int *order; int n, r, tmp; - order = kmalloc_array(count, sizeof(*order), GFP_TEMPORARY); + order = kmalloc_array(count, sizeof(*order), GFP_KERNEL); if (!order) return order; diff --git a/kernel/trace/trace_events_filter.c b/kernel/trace/trace_events_filter.c index 181e139a8057..61e7f0678d33 100644 --- a/kernel/trace/trace_events_filter.c +++ b/kernel/trace/trace_events_filter.c @@ -702,7 +702,7 @@ static void append_filter_err(struct filter_parse_state *ps, int pos = ps->lasterr_pos; char *buf, *pbuf; - buf = (char *)__get_free_page(GFP_TEMPORARY); + buf = (char *)__get_free_page(GFP_KERNEL); if (!buf) return; diff --git a/lib/string_helpers.c b/lib/string_helpers.c index ecaac2c0526f..29c490e5d478 100644 --- a/lib/string_helpers.c +++ b/lib/string_helpers.c @@ -576,7 +576,7 @@ char *kstrdup_quotable_cmdline(struct task_struct *task, gfp_t gfp) char *buffer, *quoted; int i, res; - buffer = kmalloc(PAGE_SIZE, GFP_TEMPORARY); + buffer = kmalloc(PAGE_SIZE, GFP_KERNEL); if (!buffer) return NULL; @@ -612,7 +612,7 @@ char *kstrdup_quotable_file(struct file *file, gfp_t gfp) return kstrdup("", gfp); /* We add 11 spaces for ' (deleted)' to be appended */ - temp = kmalloc(PATH_MAX + 11, GFP_TEMPORARY); + temp = kmalloc(PATH_MAX + 11, GFP_KERNEL); if (!temp) return kstrdup("", gfp); diff --git a/mm/shmem.c b/mm/shmem.c index ace53a582be5..07a1d22807be 100644 --- a/mm/shmem.c +++ b/mm/shmem.c @@ -3685,7 +3685,7 @@ SYSCALL_DEFINE2(memfd_create, if (len > MFD_NAME_MAX_LEN + 1) return -EINVAL; - name = kmalloc(len + MFD_NAME_PREFIX_LEN, GFP_TEMPORARY); + name = kmalloc(len + MFD_NAME_PREFIX_LEN, GFP_KERNEL); if (!name) return -ENOMEM; diff --git a/mm/slub.c b/mm/slub.c index d39a5d3834b3..163352c537ab 100644 --- a/mm/slub.c +++ b/mm/slub.c @@ -4597,7 +4597,7 @@ static int list_locations(struct kmem_cache *s, char *buf, struct kmem_cache_node *n; if (!map || !alloc_loc_track(&t, PAGE_SIZE / sizeof(struct location), - GFP_TEMPORARY)) { + GFP_KERNEL)) { kfree(map); return sprintf(buf, "Out of memory\n"); } diff --git a/tools/perf/builtin-kmem.c b/tools/perf/builtin-kmem.c index a1497c516d85..24ee68ecdd42 100644 --- a/tools/perf/builtin-kmem.c +++ b/tools/perf/builtin-kmem.c @@ -627,7 +627,6 @@ static const struct { { "GFP_HIGHUSER_MOVABLE", "HUM" }, { "GFP_HIGHUSER", "HU" }, { "GFP_USER", "U" }, - { "GFP_TEMPORARY", "TMP" }, { "GFP_KERNEL_ACCOUNT", "KAC" }, { "GFP_KERNEL", "K" }, { "GFP_NOFS", "NF" }, -- cgit v1.2.3 From 0647169cf9aa441700eb8f23ea49be060626534b Mon Sep 17 00:00:00 2001 From: Andreas Gruenbacher Date: Tue, 19 Sep 2017 12:41:37 +0200 Subject: rhashtable: Documentation tweak Clarify that rhashtable_walk_{stop,start} will not reset the iterator to the beginning of the hash table. Confusion between rhashtable_walk_enter and rhashtable_walk_start has already lead to a bug. Signed-off-by: Andreas Gruenbacher Signed-off-by: David S. Miller --- lib/rhashtable.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/lib/rhashtable.c b/lib/rhashtable.c index 707ca5d677c6..ddd7dde87c3c 100644 --- a/lib/rhashtable.c +++ b/lib/rhashtable.c @@ -735,9 +735,9 @@ EXPORT_SYMBOL_GPL(rhashtable_walk_exit); * rhashtable_walk_start - Start a hash table walk * @iter: Hash table iterator * - * Start a hash table walk. Note that we take the RCU lock in all - * cases including when we return an error. So you must always call - * rhashtable_walk_stop to clean up. + * Start a hash table walk at the current iterator position. Note that we take + * the RCU lock in all cases including when we return an error. So you must + * always call rhashtable_walk_stop to clean up. * * Returns zero if successful. * @@ -846,7 +846,8 @@ EXPORT_SYMBOL_GPL(rhashtable_walk_next); * rhashtable_walk_stop - Finish a hash table walk * @iter: Hash table iterator * - * Finish a hash table walk. + * Finish a hash table walk. Does not reset the iterator to the start of the + * hash table. */ void rhashtable_walk_stop(struct rhashtable_iter *iter) __releases(RCU) -- cgit v1.2.3 From a90bcb86ae700c12432446c4aa1819e7b8e172ec Mon Sep 17 00:00:00 2001 From: Petar Penkov Date: Tue, 29 Aug 2017 11:20:32 -0700 Subject: iov_iter: fix page_copy_sane for compound pages Issue is that if the data crosses a page boundary inside a compound page, this check will incorrectly trigger a WARN_ON. To fix this, compute the order using the head of the compound page and adjust the offset to be relative to that head. Fixes: 72e809ed81ed ("iov_iter: sanity checks for copy to/from page primitives") Signed-off-by: Petar Penkov CC: Al Viro CC: Eric Dumazet Signed-off-by: Al Viro --- lib/iov_iter.c | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/iov_iter.c b/lib/iov_iter.c index 52c8dd6d8e82..1c1c06ddc20a 100644 --- a/lib/iov_iter.c +++ b/lib/iov_iter.c @@ -687,8 +687,10 @@ EXPORT_SYMBOL(_copy_from_iter_full_nocache); static inline bool page_copy_sane(struct page *page, size_t offset, size_t n) { - size_t v = n + offset; - if (likely(n <= v && v <= (PAGE_SIZE << compound_order(page)))) + struct page *head = compound_head(page); + size_t v = n + offset + page_address(page) - page_address(head); + + if (likely(n <= v && v <= (PAGE_SIZE << compound_order(head)))) return true; WARN_ON(1); return false; -- cgit v1.2.3 From 432654df90f2a7d8ac8438905208074604ed3e00 Mon Sep 17 00:00:00 2001 From: Helge Deller Date: Mon, 11 Sep 2017 21:41:43 +0200 Subject: parisc: Fix too large frame size warnings The parisc architecture has larger stack frames than most other architectures on 32-bit kernels. Increase the maximum allowed stack frame to 1280 bytes for parisc to avoid warnings in the do_sys_poll() and pat_memconfig() functions. Signed-off-by: Helge Deller --- lib/Kconfig.debug | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index b19c491cbc4e..2689b7c50c52 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -219,7 +219,8 @@ config FRAME_WARN range 0 8192 default 0 if KASAN default 2048 if GCC_PLUGIN_LATENT_ENTROPY - default 1024 if !64BIT + default 1280 if (!64BIT && PARISC) + default 1024 if (!64BIT && !PARISC) default 2048 if 64BIT help Tell gcc to warn at build time for stack frames larger than this. -- cgit v1.2.3