From e5b18b0d8067098813447ae5b2c59eea9be91e2c Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:31:34 +0200 Subject: [PATCH 1/9] amd-pstate Signed-off-by: Peter Jung --- arch/x86/include/asm/msr-index.h | 20 +- arch/x86/kernel/acpi/cppc.c | 4 +- drivers/cpufreq/amd-pstate-trace.h | 13 +- drivers/cpufreq/amd-pstate-ut.c | 211 ++++----- drivers/cpufreq/amd-pstate.c | 674 ++++++++++++----------------- drivers/cpufreq/amd-pstate.h | 63 +-- include/linux/cpufreq.h | 3 + include/linux/sched/topology.h | 6 + kernel/sched/debug.c | 4 + kernel/sched/fair.c | 5 +- kernel/sched/topology.c | 58 +++ 11 files changed, 507 insertions(+), 554 deletions(-) diff --git a/arch/x86/include/asm/msr-index.h b/arch/x86/include/asm/msr-index.h index 72765b2fe0d8..fc2634cc48fd 100644 --- a/arch/x86/include/asm/msr-index.h +++ b/arch/x86/include/asm/msr-index.h @@ -701,15 +701,17 @@ #define MSR_AMD_CPPC_REQ 0xc00102b3 #define MSR_AMD_CPPC_STATUS 0xc00102b4 -#define AMD_CPPC_LOWEST_PERF(x) (((x) >> 0) & 0xff) -#define AMD_CPPC_LOWNONLIN_PERF(x) (((x) >> 8) & 0xff) -#define AMD_CPPC_NOMINAL_PERF(x) (((x) >> 16) & 0xff) -#define AMD_CPPC_HIGHEST_PERF(x) (((x) >> 24) & 0xff) - -#define AMD_CPPC_MAX_PERF(x) (((x) & 0xff) << 0) -#define AMD_CPPC_MIN_PERF(x) (((x) & 0xff) << 8) -#define AMD_CPPC_DES_PERF(x) (((x) & 0xff) << 16) -#define AMD_CPPC_ENERGY_PERF_PREF(x) (((x) & 0xff) << 24) +/* Masks for use with MSR_AMD_CPPC_CAP1 */ +#define AMD_CPPC_LOWEST_PERF_MASK GENMASK(7, 0) +#define AMD_CPPC_LOWNONLIN_PERF_MASK GENMASK(15, 8) +#define AMD_CPPC_NOMINAL_PERF_MASK GENMASK(23, 16) +#define AMD_CPPC_HIGHEST_PERF_MASK GENMASK(31, 24) + +/* Masks for use with MSR_AMD_CPPC_REQ */ +#define AMD_CPPC_MAX_PERF_MASK GENMASK(7, 0) +#define AMD_CPPC_MIN_PERF_MASK GENMASK(15, 8) +#define AMD_CPPC_DES_PERF_MASK GENMASK(23, 16) +#define AMD_CPPC_EPP_PERF_MASK GENMASK(31, 24) /* AMD Performance Counter Global Status and Control MSRs */ #define MSR_AMD64_PERF_CNTR_GLOBAL_STATUS 0xc0000300 diff --git a/arch/x86/kernel/acpi/cppc.c b/arch/x86/kernel/acpi/cppc.c index d745dd586303..77bfb846490c 100644 --- a/arch/x86/kernel/acpi/cppc.c +++ b/arch/x86/kernel/acpi/cppc.c @@ -4,6 +4,8 @@ * Copyright (c) 2016, Intel Corporation. */ +#include + #include #include #include @@ -149,7 +151,7 @@ int amd_get_highest_perf(unsigned int cpu, u32 *highest_perf) if (ret) goto out; - val = AMD_CPPC_HIGHEST_PERF(val); + val = FIELD_GET(AMD_CPPC_HIGHEST_PERF_MASK, val); } else { ret = cppc_get_highest_perf(cpu, &val); if (ret) diff --git a/drivers/cpufreq/amd-pstate-trace.h b/drivers/cpufreq/amd-pstate-trace.h index f457d4af2c62..32e1bdc588c5 100644 --- a/drivers/cpufreq/amd-pstate-trace.h +++ b/drivers/cpufreq/amd-pstate-trace.h @@ -90,7 +90,8 @@ TRACE_EVENT(amd_pstate_epp_perf, u8 epp, u8 min_perf, u8 max_perf, - bool boost + bool boost, + bool changed ), TP_ARGS(cpu_id, @@ -98,7 +99,8 @@ TRACE_EVENT(amd_pstate_epp_perf, epp, min_perf, max_perf, - boost), + boost, + changed), TP_STRUCT__entry( __field(unsigned int, cpu_id) @@ -107,6 +109,7 @@ TRACE_EVENT(amd_pstate_epp_perf, __field(u8, min_perf) __field(u8, max_perf) __field(bool, boost) + __field(bool, changed) ), TP_fast_assign( @@ -116,15 +119,17 @@ TRACE_EVENT(amd_pstate_epp_perf, __entry->min_perf = min_perf; __entry->max_perf = max_perf; __entry->boost = boost; + __entry->changed = changed; ), - TP_printk("cpu%u: [%hhu<->%hhu]/%hhu, epp=%hhu, boost=%u", + TP_printk("cpu%u: [%hhu<->%hhu]/%hhu, epp=%hhu, boost=%u, changed=%u", (unsigned int)__entry->cpu_id, (u8)__entry->min_perf, (u8)__entry->max_perf, (u8)__entry->highest_perf, (u8)__entry->epp, - (bool)__entry->boost + (bool)__entry->boost, + (bool)__entry->changed ) ); diff --git a/drivers/cpufreq/amd-pstate-ut.c b/drivers/cpufreq/amd-pstate-ut.c index 3a0a380c3590..e671bc7d1550 100644 --- a/drivers/cpufreq/amd-pstate-ut.c +++ b/drivers/cpufreq/amd-pstate-ut.c @@ -22,39 +22,31 @@ #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include #include #include #include #include +#include #include #include "amd-pstate.h" -/* - * Abbreviations: - * amd_pstate_ut: used as a shortform for AMD P-State unit test. - * It helps to keep variable names smaller, simpler - */ -enum amd_pstate_ut_result { - AMD_PSTATE_UT_RESULT_PASS, - AMD_PSTATE_UT_RESULT_FAIL, -}; struct amd_pstate_ut_struct { const char *name; - void (*func)(u32 index); - enum amd_pstate_ut_result result; + int (*func)(u32 index); }; /* * Kernel module for testing the AMD P-State unit test */ -static void amd_pstate_ut_acpi_cpc_valid(u32 index); -static void amd_pstate_ut_check_enabled(u32 index); -static void amd_pstate_ut_check_perf(u32 index); -static void amd_pstate_ut_check_freq(u32 index); -static void amd_pstate_ut_check_driver(u32 index); +static int amd_pstate_ut_acpi_cpc_valid(u32 index); +static int amd_pstate_ut_check_enabled(u32 index); +static int amd_pstate_ut_check_perf(u32 index); +static int amd_pstate_ut_check_freq(u32 index); +static int amd_pstate_ut_check_driver(u32 index); static struct amd_pstate_ut_struct amd_pstate_ut_cases[] = { {"amd_pstate_ut_acpi_cpc_valid", amd_pstate_ut_acpi_cpc_valid }, @@ -77,71 +69,67 @@ static bool get_shared_mem(void) /* * check the _CPC object is present in SBIOS. */ -static void amd_pstate_ut_acpi_cpc_valid(u32 index) +static int amd_pstate_ut_acpi_cpc_valid(u32 index) { - if (acpi_cpc_valid()) - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - else { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; + if (!acpi_cpc_valid()) { pr_err("%s the _CPC object is not present in SBIOS!\n", __func__); + return -EINVAL; } + + return 0; } -static void amd_pstate_ut_pstate_enable(u32 index) +/* + * check if amd pstate is enabled + */ +static int amd_pstate_ut_check_enabled(u32 index) { - int ret = 0; u64 cppc_enable = 0; + int ret; + + if (get_shared_mem()) + return 0; ret = rdmsrl_safe(MSR_AMD_CPPC_ENABLE, &cppc_enable); if (ret) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s rdmsrl_safe MSR_AMD_CPPC_ENABLE ret=%d error!\n", __func__, ret); - return; + return ret; } - if (cppc_enable) - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - else { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; + + if (!cppc_enable) { pr_err("%s amd pstate must be enabled!\n", __func__); + return -EINVAL; } -} -/* - * check if amd pstate is enabled - */ -static void amd_pstate_ut_check_enabled(u32 index) -{ - if (get_shared_mem()) - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - else - amd_pstate_ut_pstate_enable(index); + return 0; } /* * check if performance values are reasonable. * highest_perf >= nominal_perf > lowest_nonlinear_perf > lowest_perf > 0 */ -static void amd_pstate_ut_check_perf(u32 index) +static int amd_pstate_ut_check_perf(u32 index) { int cpu = 0, ret = 0; u32 highest_perf = 0, nominal_perf = 0, lowest_nonlinear_perf = 0, lowest_perf = 0; u64 cap1 = 0; struct cppc_perf_caps cppc_perf; - struct cpufreq_policy *policy = NULL; - struct amd_cpudata *cpudata = NULL; + union perf_cached cur_perf; + + for_each_online_cpu(cpu) { + struct cpufreq_policy *policy __free(put_cpufreq_policy) = NULL; + struct amd_cpudata *cpudata; - for_each_possible_cpu(cpu) { policy = cpufreq_cpu_get(cpu); if (!policy) - break; + continue; cpudata = policy->driver_data; if (get_shared_mem()) { ret = cppc_get_perf_caps(cpu, &cppc_perf); if (ret) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s cppc_get_perf_caps ret=%d error!\n", __func__, ret); - goto skip_test; + return ret; } highest_perf = cppc_perf.highest_perf; @@ -151,50 +139,44 @@ static void amd_pstate_ut_check_perf(u32 index) } else { ret = rdmsrl_safe_on_cpu(cpu, MSR_AMD_CPPC_CAP1, &cap1); if (ret) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s read CPPC_CAP1 ret=%d error!\n", __func__, ret); - goto skip_test; + return ret; } - highest_perf = AMD_CPPC_HIGHEST_PERF(cap1); - nominal_perf = AMD_CPPC_NOMINAL_PERF(cap1); - lowest_nonlinear_perf = AMD_CPPC_LOWNONLIN_PERF(cap1); - lowest_perf = AMD_CPPC_LOWEST_PERF(cap1); + highest_perf = FIELD_GET(AMD_CPPC_HIGHEST_PERF_MASK, cap1); + nominal_perf = FIELD_GET(AMD_CPPC_NOMINAL_PERF_MASK, cap1); + lowest_nonlinear_perf = FIELD_GET(AMD_CPPC_LOWNONLIN_PERF_MASK, cap1); + lowest_perf = FIELD_GET(AMD_CPPC_LOWEST_PERF_MASK, cap1); } - if (highest_perf != READ_ONCE(cpudata->highest_perf) && !cpudata->hw_prefcore) { + cur_perf = READ_ONCE(cpudata->perf); + if (highest_perf != cur_perf.highest_perf && !cpudata->hw_prefcore) { pr_err("%s cpu%d highest=%d %d highest perf doesn't match\n", - __func__, cpu, highest_perf, cpudata->highest_perf); - goto skip_test; + __func__, cpu, highest_perf, cur_perf.highest_perf); + return -EINVAL; } - if ((nominal_perf != READ_ONCE(cpudata->nominal_perf)) || - (lowest_nonlinear_perf != READ_ONCE(cpudata->lowest_nonlinear_perf)) || - (lowest_perf != READ_ONCE(cpudata->lowest_perf))) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; + if (nominal_perf != cur_perf.nominal_perf || + (lowest_nonlinear_perf != cur_perf.lowest_nonlinear_perf) || + (lowest_perf != cur_perf.lowest_perf)) { pr_err("%s cpu%d nominal=%d %d lowest_nonlinear=%d %d lowest=%d %d, they should be equal!\n", - __func__, cpu, nominal_perf, cpudata->nominal_perf, - lowest_nonlinear_perf, cpudata->lowest_nonlinear_perf, - lowest_perf, cpudata->lowest_perf); - goto skip_test; + __func__, cpu, nominal_perf, cur_perf.nominal_perf, + lowest_nonlinear_perf, cur_perf.lowest_nonlinear_perf, + lowest_perf, cur_perf.lowest_perf); + return -EINVAL; } if (!((highest_perf >= nominal_perf) && (nominal_perf > lowest_nonlinear_perf) && - (lowest_nonlinear_perf > lowest_perf) && + (lowest_nonlinear_perf >= lowest_perf) && (lowest_perf > 0))) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s cpu%d highest=%d >= nominal=%d > lowest_nonlinear=%d > lowest=%d > 0, the formula is incorrect!\n", __func__, cpu, highest_perf, nominal_perf, lowest_nonlinear_perf, lowest_perf); - goto skip_test; + return -EINVAL; } - cpufreq_cpu_put(policy); } - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - return; -skip_test: - cpufreq_cpu_put(policy); + return 0; } /* @@ -202,59 +184,50 @@ static void amd_pstate_ut_check_perf(u32 index) * max_freq >= nominal_freq > lowest_nonlinear_freq > min_freq > 0 * check max freq when set support boost mode. */ -static void amd_pstate_ut_check_freq(u32 index) +static int amd_pstate_ut_check_freq(u32 index) { int cpu = 0; - struct cpufreq_policy *policy = NULL; - struct amd_cpudata *cpudata = NULL; - for_each_possible_cpu(cpu) { + for_each_online_cpu(cpu) { + struct cpufreq_policy *policy __free(put_cpufreq_policy) = NULL; + struct amd_cpudata *cpudata; + policy = cpufreq_cpu_get(cpu); if (!policy) - break; + continue; cpudata = policy->driver_data; - if (!((cpudata->max_freq >= cpudata->nominal_freq) && + if (!((policy->cpuinfo.max_freq >= cpudata->nominal_freq) && (cpudata->nominal_freq > cpudata->lowest_nonlinear_freq) && - (cpudata->lowest_nonlinear_freq > cpudata->min_freq) && - (cpudata->min_freq > 0))) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; + (cpudata->lowest_nonlinear_freq >= policy->cpuinfo.min_freq) && + (policy->cpuinfo.min_freq > 0))) { pr_err("%s cpu%d max=%d >= nominal=%d > lowest_nonlinear=%d > min=%d > 0, the formula is incorrect!\n", - __func__, cpu, cpudata->max_freq, cpudata->nominal_freq, - cpudata->lowest_nonlinear_freq, cpudata->min_freq); - goto skip_test; + __func__, cpu, policy->cpuinfo.max_freq, cpudata->nominal_freq, + cpudata->lowest_nonlinear_freq, policy->cpuinfo.min_freq); + return -EINVAL; } if (cpudata->lowest_nonlinear_freq != policy->min) { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s cpu%d cpudata_lowest_nonlinear_freq=%d policy_min=%d, they should be equal!\n", __func__, cpu, cpudata->lowest_nonlinear_freq, policy->min); - goto skip_test; + return -EINVAL; } if (cpudata->boost_supported) { - if ((policy->max == cpudata->max_freq) || - (policy->max == cpudata->nominal_freq)) - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - else { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; + if ((policy->max != policy->cpuinfo.max_freq) && + (policy->max != cpudata->nominal_freq)) { pr_err("%s cpu%d policy_max=%d should be equal cpu_max=%d or cpu_nominal=%d !\n", - __func__, cpu, policy->max, cpudata->max_freq, + __func__, cpu, policy->max, policy->cpuinfo.max_freq, cpudata->nominal_freq); - goto skip_test; + return -EINVAL; } } else { - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; pr_err("%s cpu%d must support boost!\n", __func__, cpu); - goto skip_test; + return -EINVAL; } - cpufreq_cpu_put(policy); } - amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_PASS; - return; -skip_test: - cpufreq_cpu_put(policy); + return 0; } static int amd_pstate_set_mode(enum amd_pstate_mode mode) @@ -266,32 +239,28 @@ static int amd_pstate_set_mode(enum amd_pstate_mode mode) return amd_pstate_update_status(mode_str, strlen(mode_str)); } -static void amd_pstate_ut_check_driver(u32 index) +static int amd_pstate_ut_check_driver(u32 index) { enum amd_pstate_mode mode1, mode2 = AMD_PSTATE_DISABLE; - int ret; for (mode1 = AMD_PSTATE_DISABLE; mode1 < AMD_PSTATE_MAX; mode1++) { - ret = amd_pstate_set_mode(mode1); + int ret = amd_pstate_set_mode(mode1); if (ret) - goto out; + return ret; for (mode2 = AMD_PSTATE_DISABLE; mode2 < AMD_PSTATE_MAX; mode2++) { if (mode1 == mode2) continue; ret = amd_pstate_set_mode(mode2); - if (ret) - goto out; + if (ret) { + pr_err("%s: failed to update status for %s->%s\n", __func__, + amd_pstate_get_mode_string(mode1), + amd_pstate_get_mode_string(mode2)); + return ret; + } } } -out: - if (ret) - pr_warn("%s: failed to update status for %s->%s: %d\n", __func__, - amd_pstate_get_mode_string(mode1), - amd_pstate_get_mode_string(mode2), ret); - - amd_pstate_ut_cases[index].result = ret ? - AMD_PSTATE_UT_RESULT_FAIL : - AMD_PSTATE_UT_RESULT_PASS; + + return 0; } static int __init amd_pstate_ut_init(void) @@ -299,16 +268,12 @@ static int __init amd_pstate_ut_init(void) u32 i = 0, arr_size = ARRAY_SIZE(amd_pstate_ut_cases); for (i = 0; i < arr_size; i++) { - amd_pstate_ut_cases[i].func(i); - switch (amd_pstate_ut_cases[i].result) { - case AMD_PSTATE_UT_RESULT_PASS: + int ret = amd_pstate_ut_cases[i].func(i); + + if (ret) + pr_err("%-4d %-20s\t fail: %d!\n", i+1, amd_pstate_ut_cases[i].name, ret); + else pr_info("%-4d %-20s\t success!\n", i+1, amd_pstate_ut_cases[i].name); - break; - case AMD_PSTATE_UT_RESULT_FAIL: - default: - pr_info("%-4d %-20s\t fail!\n", i+1, amd_pstate_ut_cases[i].name); - break; - } } return 0; diff --git a/drivers/cpufreq/amd-pstate.c b/drivers/cpufreq/amd-pstate.c index 1b26845703f6..f62dd09c1e4c 100644 --- a/drivers/cpufreq/amd-pstate.c +++ b/drivers/cpufreq/amd-pstate.c @@ -85,15 +85,9 @@ static struct cpufreq_driver *current_pstate_driver; static struct cpufreq_driver amd_pstate_driver; static struct cpufreq_driver amd_pstate_epp_driver; static int cppc_state = AMD_PSTATE_UNDEFINED; -static bool cppc_enabled; static bool amd_pstate_prefcore = true; static struct quirk_entry *quirks; -#define AMD_CPPC_MAX_PERF_MASK GENMASK(7, 0) -#define AMD_CPPC_MIN_PERF_MASK GENMASK(15, 8) -#define AMD_CPPC_DES_PERF_MASK GENMASK(23, 16) -#define AMD_CPPC_EPP_PERF_MASK GENMASK(31, 24) - /* * AMD Energy Preference Performance (EPP) * The EPP is used in the CCLK DPM controller to drive @@ -142,6 +136,19 @@ static struct quirk_entry quirk_amd_7k62 = { .lowest_freq = 550, }; +static inline u8 freq_to_perf(union perf_cached perf, u32 nominal_freq, unsigned int freq_val) +{ + u32 perf_val = DIV_ROUND_UP_ULL((u64)freq_val * perf.nominal_perf, nominal_freq); + + return (u8)clamp(perf_val, perf.lowest_perf, perf.highest_perf); +} + +static inline u32 perf_to_freq(union perf_cached perf, u32 nominal_freq, u8 perf_val) +{ + return DIV_ROUND_UP_ULL((u64)nominal_freq * perf_val, + perf.nominal_perf); +} + static int __init dmi_matched_7k62_bios_bug(const struct dmi_system_id *dmi) { /** @@ -183,7 +190,6 @@ static inline int get_mode_idx_from_str(const char *str, size_t size) return -EINVAL; } -static DEFINE_MUTEX(amd_pstate_limits_lock); static DEFINE_MUTEX(amd_pstate_driver_lock); static u8 msr_get_epp(struct amd_cpudata *cpudata) @@ -221,9 +227,10 @@ static u8 shmem_get_epp(struct amd_cpudata *cpudata) return FIELD_GET(AMD_CPPC_EPP_PERF_MASK, epp); } -static int msr_update_perf(struct amd_cpudata *cpudata, u8 min_perf, +static int msr_update_perf(struct cpufreq_policy *policy, u8 min_perf, u8 des_perf, u8 max_perf, u8 epp, bool fast_switch) { + struct amd_cpudata *cpudata = policy->driver_data; u64 value, prev; value = prev = READ_ONCE(cpudata->cppc_req_cached); @@ -235,6 +242,18 @@ static int msr_update_perf(struct amd_cpudata *cpudata, u8 min_perf, value |= FIELD_PREP(AMD_CPPC_MIN_PERF_MASK, min_perf); value |= FIELD_PREP(AMD_CPPC_EPP_PERF_MASK, epp); + if (trace_amd_pstate_epp_perf_enabled()) { + union perf_cached perf = READ_ONCE(cpudata->perf); + + trace_amd_pstate_epp_perf(cpudata->cpu, + perf.highest_perf, + epp, + min_perf, + max_perf, + policy->boost_enabled, + value != prev); + } + if (value == prev) return 0; @@ -249,24 +268,24 @@ static int msr_update_perf(struct amd_cpudata *cpudata, u8 min_perf, } WRITE_ONCE(cpudata->cppc_req_cached, value); - WRITE_ONCE(cpudata->epp_cached, epp); return 0; } DEFINE_STATIC_CALL(amd_pstate_update_perf, msr_update_perf); -static inline int amd_pstate_update_perf(struct amd_cpudata *cpudata, +static inline int amd_pstate_update_perf(struct cpufreq_policy *policy, u8 min_perf, u8 des_perf, u8 max_perf, u8 epp, bool fast_switch) { - return static_call(amd_pstate_update_perf)(cpudata, min_perf, des_perf, + return static_call(amd_pstate_update_perf)(policy, min_perf, des_perf, max_perf, epp, fast_switch); } -static int msr_set_epp(struct amd_cpudata *cpudata, u8 epp) +static int msr_set_epp(struct cpufreq_policy *policy, u8 epp) { + struct amd_cpudata *cpudata = policy->driver_data; u64 value, prev; int ret; @@ -274,6 +293,19 @@ static int msr_set_epp(struct amd_cpudata *cpudata, u8 epp) value &= ~AMD_CPPC_EPP_PERF_MASK; value |= FIELD_PREP(AMD_CPPC_EPP_PERF_MASK, epp); + if (trace_amd_pstate_epp_perf_enabled()) { + union perf_cached perf = cpudata->perf; + + trace_amd_pstate_epp_perf(cpudata->cpu, perf.highest_perf, + epp, + FIELD_GET(AMD_CPPC_MIN_PERF_MASK, + cpudata->cppc_req_cached), + FIELD_GET(AMD_CPPC_MAX_PERF_MASK, + cpudata->cppc_req_cached), + policy->boost_enabled, + value != prev); + } + if (value == prev) return 0; @@ -284,7 +316,6 @@ static int msr_set_epp(struct amd_cpudata *cpudata, u8 epp) } /* update both so that msr_update_perf() can effectively check */ - WRITE_ONCE(cpudata->epp_cached, epp); WRITE_ONCE(cpudata->cppc_req_cached, value); return ret; @@ -292,17 +323,35 @@ static int msr_set_epp(struct amd_cpudata *cpudata, u8 epp) DEFINE_STATIC_CALL(amd_pstate_set_epp, msr_set_epp); -static inline int amd_pstate_set_epp(struct amd_cpudata *cpudata, u8 epp) +static inline int amd_pstate_set_epp(struct cpufreq_policy *policy, u8 epp) { - return static_call(amd_pstate_set_epp)(cpudata, epp); + return static_call(amd_pstate_set_epp)(policy, epp); } -static int shmem_set_epp(struct amd_cpudata *cpudata, u8 epp) +static int shmem_set_epp(struct cpufreq_policy *policy, u8 epp) { - int ret; + struct amd_cpudata *cpudata = policy->driver_data; struct cppc_perf_ctrls perf_ctrls; + u8 epp_cached; + u64 value; + int ret; - if (epp == cpudata->epp_cached) + + epp_cached = FIELD_GET(AMD_CPPC_EPP_PERF_MASK, cpudata->cppc_req_cached); + if (trace_amd_pstate_epp_perf_enabled()) { + union perf_cached perf = cpudata->perf; + + trace_amd_pstate_epp_perf(cpudata->cpu, perf.highest_perf, + epp, + FIELD_GET(AMD_CPPC_MIN_PERF_MASK, + cpudata->cppc_req_cached), + FIELD_GET(AMD_CPPC_MAX_PERF_MASK, + cpudata->cppc_req_cached), + policy->boost_enabled, + epp != epp_cached); + } + + if (epp == epp_cached) return 0; perf_ctrls.energy_perf = epp; @@ -311,106 +360,35 @@ static int shmem_set_epp(struct amd_cpudata *cpudata, u8 epp) pr_debug("failed to set energy perf value (%d)\n", ret); return ret; } - WRITE_ONCE(cpudata->epp_cached, epp); - - return ret; -} - -static int amd_pstate_set_energy_pref_index(struct cpufreq_policy *policy, - int pref_index) -{ - struct amd_cpudata *cpudata = policy->driver_data; - u8 epp; - - if (!pref_index) - epp = cpudata->epp_default; - else - epp = epp_values[pref_index]; - - if (epp > 0 && cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) { - pr_debug("EPP cannot be set under performance policy\n"); - return -EBUSY; - } - if (trace_amd_pstate_epp_perf_enabled()) { - trace_amd_pstate_epp_perf(cpudata->cpu, cpudata->highest_perf, - epp, - FIELD_GET(AMD_CPPC_MIN_PERF_MASK, cpudata->cppc_req_cached), - FIELD_GET(AMD_CPPC_MAX_PERF_MASK, cpudata->cppc_req_cached), - policy->boost_enabled); - } + value = READ_ONCE(cpudata->cppc_req_cached); + value &= ~AMD_CPPC_EPP_PERF_MASK; + value |= FIELD_PREP(AMD_CPPC_EPP_PERF_MASK, epp); + WRITE_ONCE(cpudata->cppc_req_cached, value); - return amd_pstate_set_epp(cpudata, epp); + return ret; } -static inline int msr_cppc_enable(bool enable) +static inline int msr_cppc_enable(struct cpufreq_policy *policy) { - int ret, cpu; - unsigned long logical_proc_id_mask = 0; - - /* - * MSR_AMD_CPPC_ENABLE is write-once, once set it cannot be cleared. - */ - if (!enable) - return 0; - - if (enable == cppc_enabled) - return 0; - - for_each_present_cpu(cpu) { - unsigned long logical_id = topology_logical_package_id(cpu); - - if (test_bit(logical_id, &logical_proc_id_mask)) - continue; - - set_bit(logical_id, &logical_proc_id_mask); - - ret = wrmsrl_safe_on_cpu(cpu, MSR_AMD_CPPC_ENABLE, - enable); - if (ret) - return ret; - } - - cppc_enabled = enable; - return 0; + return wrmsrl_safe_on_cpu(policy->cpu, MSR_AMD_CPPC_ENABLE, 1); } -static int shmem_cppc_enable(bool enable) +static int shmem_cppc_enable(struct cpufreq_policy *policy) { - int cpu, ret = 0; - struct cppc_perf_ctrls perf_ctrls; - - if (enable == cppc_enabled) - return 0; - - for_each_present_cpu(cpu) { - ret = cppc_set_enable(cpu, enable); - if (ret) - return ret; - - /* Enable autonomous mode for EPP */ - if (cppc_state == AMD_PSTATE_ACTIVE) { - /* Set desired perf as zero to allow EPP firmware control */ - perf_ctrls.desired_perf = 0; - ret = cppc_set_perf(cpu, &perf_ctrls); - if (ret) - return ret; - } - } - - cppc_enabled = enable; - return ret; + return cppc_set_enable(policy->cpu, 1); } DEFINE_STATIC_CALL(amd_pstate_cppc_enable, msr_cppc_enable); -static inline int amd_pstate_cppc_enable(bool enable) +static inline int amd_pstate_cppc_enable(struct cpufreq_policy *policy) { - return static_call(amd_pstate_cppc_enable)(enable); + return static_call(amd_pstate_cppc_enable)(policy); } static int msr_init_perf(struct amd_cpudata *cpudata) { + union perf_cached perf = READ_ONCE(cpudata->perf); u64 cap1, numerator; int ret = rdmsrl_safe_on_cpu(cpudata->cpu, MSR_AMD_CPPC_CAP1, @@ -422,19 +400,22 @@ static int msr_init_perf(struct amd_cpudata *cpudata) if (ret) return ret; - WRITE_ONCE(cpudata->highest_perf, numerator); - WRITE_ONCE(cpudata->max_limit_perf, numerator); - WRITE_ONCE(cpudata->nominal_perf, AMD_CPPC_NOMINAL_PERF(cap1)); - WRITE_ONCE(cpudata->lowest_nonlinear_perf, AMD_CPPC_LOWNONLIN_PERF(cap1)); - WRITE_ONCE(cpudata->lowest_perf, AMD_CPPC_LOWEST_PERF(cap1)); - WRITE_ONCE(cpudata->prefcore_ranking, AMD_CPPC_HIGHEST_PERF(cap1)); - WRITE_ONCE(cpudata->min_limit_perf, AMD_CPPC_LOWEST_PERF(cap1)); + perf.highest_perf = numerator; + perf.max_limit_perf = numerator; + perf.min_limit_perf = FIELD_GET(AMD_CPPC_LOWEST_PERF_MASK, cap1); + perf.nominal_perf = FIELD_GET(AMD_CPPC_NOMINAL_PERF_MASK, cap1); + perf.lowest_nonlinear_perf = FIELD_GET(AMD_CPPC_LOWNONLIN_PERF_MASK, cap1); + perf.lowest_perf = FIELD_GET(AMD_CPPC_LOWEST_PERF_MASK, cap1); + WRITE_ONCE(cpudata->perf, perf); + WRITE_ONCE(cpudata->prefcore_ranking, FIELD_GET(AMD_CPPC_HIGHEST_PERF_MASK, cap1)); + return 0; } static int shmem_init_perf(struct amd_cpudata *cpudata) { struct cppc_perf_caps cppc_perf; + union perf_cached perf = READ_ONCE(cpudata->perf); u64 numerator; int ret = cppc_get_perf_caps(cpudata->cpu, &cppc_perf); @@ -445,14 +426,14 @@ static int shmem_init_perf(struct amd_cpudata *cpudata) if (ret) return ret; - WRITE_ONCE(cpudata->highest_perf, numerator); - WRITE_ONCE(cpudata->max_limit_perf, numerator); - WRITE_ONCE(cpudata->nominal_perf, cppc_perf.nominal_perf); - WRITE_ONCE(cpudata->lowest_nonlinear_perf, - cppc_perf.lowest_nonlinear_perf); - WRITE_ONCE(cpudata->lowest_perf, cppc_perf.lowest_perf); + perf.highest_perf = numerator; + perf.max_limit_perf = numerator; + perf.min_limit_perf = cppc_perf.lowest_perf; + perf.nominal_perf = cppc_perf.nominal_perf; + perf.lowest_nonlinear_perf = cppc_perf.lowest_nonlinear_perf; + perf.lowest_perf = cppc_perf.lowest_perf; + WRITE_ONCE(cpudata->perf, perf); WRITE_ONCE(cpudata->prefcore_ranking, cppc_perf.highest_perf); - WRITE_ONCE(cpudata->min_limit_perf, cppc_perf.lowest_perf); if (cppc_state == AMD_PSTATE_ACTIVE) return 0; @@ -479,23 +460,56 @@ static inline int amd_pstate_init_perf(struct amd_cpudata *cpudata) return static_call(amd_pstate_init_perf)(cpudata); } -static int shmem_update_perf(struct amd_cpudata *cpudata, u8 min_perf, +static int shmem_update_perf(struct cpufreq_policy *policy, u8 min_perf, u8 des_perf, u8 max_perf, u8 epp, bool fast_switch) { + struct amd_cpudata *cpudata = policy->driver_data; struct cppc_perf_ctrls perf_ctrls; + u64 value, prev; + int ret; if (cppc_state == AMD_PSTATE_ACTIVE) { - int ret = shmem_set_epp(cpudata, epp); + int ret = shmem_set_epp(policy, epp); if (ret) return ret; } + value = prev = READ_ONCE(cpudata->cppc_req_cached); + + value &= ~(AMD_CPPC_MAX_PERF_MASK | AMD_CPPC_MIN_PERF_MASK | + AMD_CPPC_DES_PERF_MASK | AMD_CPPC_EPP_PERF_MASK); + value |= FIELD_PREP(AMD_CPPC_MAX_PERF_MASK, max_perf); + value |= FIELD_PREP(AMD_CPPC_DES_PERF_MASK, des_perf); + value |= FIELD_PREP(AMD_CPPC_MIN_PERF_MASK, min_perf); + value |= FIELD_PREP(AMD_CPPC_EPP_PERF_MASK, epp); + + if (trace_amd_pstate_epp_perf_enabled()) { + union perf_cached perf = READ_ONCE(cpudata->perf); + + trace_amd_pstate_epp_perf(cpudata->cpu, + perf.highest_perf, + epp, + min_perf, + max_perf, + policy->boost_enabled, + value != prev); + } + + if (value == prev) + return 0; + perf_ctrls.max_perf = max_perf; perf_ctrls.min_perf = min_perf; perf_ctrls.desired_perf = des_perf; - return cppc_set_perf(cpudata->cpu, &perf_ctrls); + ret = cppc_set_perf(cpudata->cpu, &perf_ctrls); + if (ret) + return ret; + + WRITE_ONCE(cpudata->cppc_req_cached, value); + + return 0; } static inline bool amd_pstate_sample(struct amd_cpudata *cpudata) @@ -534,17 +548,15 @@ static inline bool amd_pstate_sample(struct amd_cpudata *cpudata) static void amd_pstate_update(struct amd_cpudata *cpudata, u8 min_perf, u8 des_perf, u8 max_perf, bool fast_switch, int gov_flags) { - unsigned long max_freq; - struct cpufreq_policy *policy = cpufreq_cpu_get(cpudata->cpu); - u8 nominal_perf = READ_ONCE(cpudata->nominal_perf); + struct cpufreq_policy *policy __free(put_cpufreq_policy) = cpufreq_cpu_get(cpudata->cpu); + union perf_cached perf = READ_ONCE(cpudata->perf); if (!policy) return; des_perf = clamp_t(u8, des_perf, min_perf, max_perf); - max_freq = READ_ONCE(cpudata->max_limit_freq); - policy->cur = div_u64(des_perf * max_freq, max_perf); + policy->cur = perf_to_freq(perf, cpudata->nominal_freq, des_perf); if ((cppc_state == AMD_PSTATE_GUIDED) && (gov_flags & CPUFREQ_GOV_DYNAMIC_SWITCHING)) { min_perf = des_perf; @@ -553,7 +565,7 @@ static void amd_pstate_update(struct amd_cpudata *cpudata, u8 min_perf, /* limit the max perf when core performance boost feature is disabled */ if (!cpudata->boost_supported) - max_perf = min_t(u8, nominal_perf, max_perf); + max_perf = min_t(u8, perf.nominal_perf, max_perf); if (trace_amd_pstate_perf_enabled() && amd_pstate_sample(cpudata)) { trace_amd_pstate_perf(min_perf, des_perf, max_perf, cpudata->freq, @@ -561,9 +573,7 @@ static void amd_pstate_update(struct amd_cpudata *cpudata, u8 min_perf, cpudata->cpu, fast_switch); } - amd_pstate_update_perf(cpudata, min_perf, des_perf, max_perf, 0, fast_switch); - - cpufreq_cpu_put(policy); + amd_pstate_update_perf(policy, min_perf, des_perf, max_perf, 0, fast_switch); } static int amd_pstate_verify(struct cpufreq_policy_data *policy_data) @@ -575,7 +585,8 @@ static int amd_pstate_verify(struct cpufreq_policy_data *policy_data) * amd-pstate qos_requests. */ if (policy_data->min == FREQ_QOS_MIN_DEFAULT_VALUE) { - struct cpufreq_policy *policy = cpufreq_cpu_get(policy_data->cpu); + struct cpufreq_policy *policy __free(put_cpufreq_policy) = + cpufreq_cpu_get(policy_data->cpu); struct amd_cpudata *cpudata; if (!policy) @@ -583,57 +594,51 @@ static int amd_pstate_verify(struct cpufreq_policy_data *policy_data) cpudata = policy->driver_data; policy_data->min = cpudata->lowest_nonlinear_freq; - cpufreq_cpu_put(policy); } cpufreq_verify_within_cpu_limits(policy_data); - pr_debug("policy_max =%d, policy_min=%d\n", policy_data->max, policy_data->min); return 0; } -static int amd_pstate_update_min_max_limit(struct cpufreq_policy *policy) +static void amd_pstate_update_min_max_limit(struct cpufreq_policy *policy) { - u8 max_limit_perf, min_limit_perf, max_perf; - u32 max_freq; struct amd_cpudata *cpudata = policy->driver_data; + union perf_cached perf = READ_ONCE(cpudata->perf); - max_perf = READ_ONCE(cpudata->highest_perf); - max_freq = READ_ONCE(cpudata->max_freq); - max_limit_perf = div_u64(policy->max * max_perf, max_freq); - min_limit_perf = div_u64(policy->min * max_perf, max_freq); - - if (cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) - min_limit_perf = min(cpudata->nominal_perf, max_limit_perf); - - WRITE_ONCE(cpudata->max_limit_perf, max_limit_perf); - WRITE_ONCE(cpudata->min_limit_perf, min_limit_perf); + perf.max_limit_perf = freq_to_perf(perf, cpudata->nominal_freq, policy->max); WRITE_ONCE(cpudata->max_limit_freq, policy->max); - WRITE_ONCE(cpudata->min_limit_freq, policy->min); - return 0; + if (cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) { + perf.min_limit_perf = min(perf.nominal_perf, perf.max_limit_perf); + WRITE_ONCE(cpudata->min_limit_freq, min(cpudata->nominal_freq, cpudata->max_limit_freq)); + } else { + perf.min_limit_perf = freq_to_perf(perf, cpudata->nominal_freq, policy->min); + WRITE_ONCE(cpudata->min_limit_freq, policy->min); + } + + WRITE_ONCE(cpudata->perf, perf); } static int amd_pstate_update_freq(struct cpufreq_policy *policy, unsigned int target_freq, bool fast_switch) { struct cpufreq_freqs freqs; - struct amd_cpudata *cpudata = policy->driver_data; - u8 des_perf, cap_perf; + struct amd_cpudata *cpudata; + union perf_cached perf; + u8 des_perf; - if (!cpudata->max_freq) - return -ENODEV; + cpudata = policy->driver_data; if (policy->min != cpudata->min_limit_freq || policy->max != cpudata->max_limit_freq) amd_pstate_update_min_max_limit(policy); - cap_perf = READ_ONCE(cpudata->highest_perf); + perf = READ_ONCE(cpudata->perf); freqs.old = policy->cur; freqs.new = target_freq; - des_perf = DIV_ROUND_CLOSEST(target_freq * cap_perf, - cpudata->max_freq); + des_perf = freq_to_perf(perf, cpudata->nominal_freq, target_freq); WARN_ON(fast_switch && !policy->fast_switch_enabled); /* @@ -644,8 +649,8 @@ static int amd_pstate_update_freq(struct cpufreq_policy *policy, if (!fast_switch) cpufreq_freq_transition_begin(policy, &freqs); - amd_pstate_update(cpudata, cpudata->min_limit_perf, des_perf, - cpudata->max_limit_perf, fast_switch, + amd_pstate_update(cpudata, perf.min_limit_perf, des_perf, + perf.max_limit_perf, fast_switch, policy->governor->flags); if (!fast_switch) @@ -674,9 +679,10 @@ static void amd_pstate_adjust_perf(unsigned int cpu, unsigned long target_perf, unsigned long capacity) { - u8 max_perf, min_perf, des_perf, cap_perf, min_limit_perf; - struct cpufreq_policy *policy = cpufreq_cpu_get(cpu); + u8 max_perf, min_perf, des_perf, cap_perf; + struct cpufreq_policy *policy __free(put_cpufreq_policy) = cpufreq_cpu_get(cpu); struct amd_cpudata *cpudata; + union perf_cached perf; if (!policy) return; @@ -686,8 +692,8 @@ static void amd_pstate_adjust_perf(unsigned int cpu, if (policy->min != cpudata->min_limit_freq || policy->max != cpudata->max_limit_freq) amd_pstate_update_min_max_limit(policy); - cap_perf = READ_ONCE(cpudata->highest_perf); - min_limit_perf = READ_ONCE(cpudata->min_limit_perf); + perf = READ_ONCE(cpudata->perf); + cap_perf = perf.highest_perf; des_perf = cap_perf; if (target_perf < capacity) @@ -698,28 +704,26 @@ static void amd_pstate_adjust_perf(unsigned int cpu, else min_perf = cap_perf; - if (min_perf < min_limit_perf) - min_perf = min_limit_perf; + if (min_perf < perf.min_limit_perf) + min_perf = perf.min_limit_perf; - max_perf = cpudata->max_limit_perf; + max_perf = perf.max_limit_perf; if (max_perf < min_perf) max_perf = min_perf; - des_perf = clamp_t(unsigned long, des_perf, min_perf, max_perf); - amd_pstate_update(cpudata, min_perf, des_perf, max_perf, true, policy->governor->flags); - cpufreq_cpu_put(policy); } static int amd_pstate_cpu_boost_update(struct cpufreq_policy *policy, bool on) { struct amd_cpudata *cpudata = policy->driver_data; + union perf_cached perf = READ_ONCE(cpudata->perf); u32 nominal_freq, max_freq; int ret = 0; nominal_freq = READ_ONCE(cpudata->nominal_freq); - max_freq = READ_ONCE(cpudata->max_freq); + max_freq = perf_to_freq(perf, cpudata->nominal_freq, perf.highest_perf); if (on) policy->cpuinfo.max_freq = max_freq; @@ -746,7 +750,6 @@ static int amd_pstate_set_boost(struct cpufreq_policy *policy, int state) pr_err("Boost mode is not supported by this processor or SBIOS\n"); return -EOPNOTSUPP; } - guard(mutex)(&amd_pstate_driver_lock); ret = amd_pstate_cpu_boost_update(policy, state); refresh_frequency_limits(policy); @@ -791,19 +794,9 @@ static void amd_perf_ctl_reset(unsigned int cpu) wrmsrl_on_cpu(cpu, MSR_AMD_PERF_CTL, 0); } -/* - * Set amd-pstate preferred core enable can't be done directly from cpufreq callbacks - * due to locking, so queue the work for later. - */ -static void amd_pstste_sched_prefcore_workfn(struct work_struct *work) -{ - sched_set_itmt_support(); -} -static DECLARE_WORK(sched_prefcore_work, amd_pstste_sched_prefcore_workfn); - #define CPPC_MAX_PERF U8_MAX -static void amd_pstate_init_prefcore(struct amd_cpudata *cpudata) +static void amd_pstate_init_asym_prio(struct amd_cpudata *cpudata) { /* user disabled or not detected */ if (!amd_pstate_prefcore) @@ -811,54 +804,38 @@ static void amd_pstate_init_prefcore(struct amd_cpudata *cpudata) cpudata->hw_prefcore = true; - /* - * The priorities can be set regardless of whether or not - * sched_set_itmt_support(true) has been called and it is valid to - * update them at any time after it has been called. - */ + /* The priorities must be initialized before ITMT support can be toggled on. */ sched_set_itmt_core_prio((int)READ_ONCE(cpudata->prefcore_ranking), cpudata->cpu); - - schedule_work(&sched_prefcore_work); } static void amd_pstate_update_limits(unsigned int cpu) { - struct cpufreq_policy *policy = NULL; + struct cpufreq_policy *policy __free(put_cpufreq_policy) = cpufreq_cpu_get(cpu); struct amd_cpudata *cpudata; u32 prev_high = 0, cur_high = 0; - int ret; bool highest_perf_changed = false; if (!amd_pstate_prefcore) return; - policy = cpufreq_cpu_get(cpu); if (!policy) return; - cpudata = policy->driver_data; - - guard(mutex)(&amd_pstate_driver_lock); - - ret = amd_get_highest_perf(cpu, &cur_high); - if (ret) { - cpufreq_cpu_put(policy); + if (amd_get_highest_perf(cpu, &cur_high)) return; - } + + cpudata = policy->driver_data; prev_high = READ_ONCE(cpudata->prefcore_ranking); highest_perf_changed = (prev_high != cur_high); if (highest_perf_changed) { WRITE_ONCE(cpudata->prefcore_ranking, cur_high); - if (cur_high < CPPC_MAX_PERF) + if (cur_high < CPPC_MAX_PERF) { sched_set_itmt_core_prio((int)cur_high, cpu); + sched_update_asym_prefer_cpu(cpu, prev_high, cur_high); + } } - cpufreq_cpu_put(policy); - - if (!highest_perf_changed) - cpufreq_update_policy(cpu); - } /* @@ -896,48 +873,45 @@ static u32 amd_pstate_get_transition_latency(unsigned int cpu) } /* - * amd_pstate_init_freq: Initialize the max_freq, min_freq, - * nominal_freq and lowest_nonlinear_freq for - * the @cpudata object. + * amd_pstate_init_freq: Initialize the nominal_freq and lowest_nonlinear_freq + * for the @cpudata object. * - * Requires: highest_perf, lowest_perf, nominal_perf and - * lowest_nonlinear_perf members of @cpudata to be - * initialized. + * Requires: all perf members of @cpudata to be initialized. * - * Returns 0 on success, non-zero value on failure. + * Returns 0 on success, non-zero value on failure. */ static int amd_pstate_init_freq(struct amd_cpudata *cpudata) { - int ret; - u32 min_freq, max_freq; - u8 highest_perf, nominal_perf, lowest_nonlinear_perf; - u32 nominal_freq, lowest_nonlinear_freq; + u32 min_freq, max_freq, nominal_freq, lowest_nonlinear_freq; struct cppc_perf_caps cppc_perf; + union perf_cached perf; + int ret; ret = cppc_get_perf_caps(cpudata->cpu, &cppc_perf); if (ret) return ret; + perf = READ_ONCE(cpudata->perf); - if (quirks && quirks->lowest_freq) + if (quirks && quirks->nominal_freq) + nominal_freq = quirks->nominal_freq; + else + nominal_freq = cppc_perf.nominal_freq; + nominal_freq *= 1000; + + if (quirks && quirks->lowest_freq) { min_freq = quirks->lowest_freq; - else + perf.lowest_perf = freq_to_perf(perf, nominal_freq, min_freq); + WRITE_ONCE(cpudata->perf, perf); + } else min_freq = cppc_perf.lowest_freq; - if (quirks && quirks->nominal_freq) - nominal_freq = quirks->nominal_freq; - else - nominal_freq = cppc_perf.nominal_freq; + min_freq *= 1000; - highest_perf = READ_ONCE(cpudata->highest_perf); - nominal_perf = READ_ONCE(cpudata->nominal_perf); - max_freq = div_u64((u64)highest_perf * nominal_freq, nominal_perf); + WRITE_ONCE(cpudata->nominal_freq, nominal_freq); - lowest_nonlinear_perf = READ_ONCE(cpudata->lowest_nonlinear_perf); - lowest_nonlinear_freq = div_u64((u64)nominal_freq * lowest_nonlinear_perf, nominal_perf); - WRITE_ONCE(cpudata->min_freq, min_freq * 1000); - WRITE_ONCE(cpudata->lowest_nonlinear_freq, lowest_nonlinear_freq * 1000); - WRITE_ONCE(cpudata->nominal_freq, nominal_freq * 1000); - WRITE_ONCE(cpudata->max_freq, max_freq * 1000); + max_freq = perf_to_freq(perf, nominal_freq, perf.highest_perf); + lowest_nonlinear_freq = perf_to_freq(perf, nominal_freq, perf.lowest_nonlinear_perf); + WRITE_ONCE(cpudata->lowest_nonlinear_freq, lowest_nonlinear_freq); /** * Below values need to be initialized correctly, otherwise driver will fail to load @@ -962,9 +936,10 @@ static int amd_pstate_init_freq(struct amd_cpudata *cpudata) static int amd_pstate_cpu_init(struct cpufreq_policy *policy) { - int min_freq, max_freq, ret; - struct device *dev; struct amd_cpudata *cpudata; + union perf_cached perf; + struct device *dev; + int ret; /* * Resetting PERF_CTL_MSR will put the CPU in P0 frequency, @@ -985,7 +960,7 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) if (ret) goto free_cpudata1; - amd_pstate_init_prefcore(cpudata); + amd_pstate_init_asym_prio(cpudata); ret = amd_pstate_init_freq(cpudata); if (ret) @@ -995,17 +970,21 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) if (ret) goto free_cpudata1; - min_freq = READ_ONCE(cpudata->min_freq); - max_freq = READ_ONCE(cpudata->max_freq); - policy->cpuinfo.transition_latency = amd_pstate_get_transition_latency(policy->cpu); policy->transition_delay_us = amd_pstate_get_transition_delay_us(policy->cpu); - policy->min = min_freq; - policy->max = max_freq; + perf = READ_ONCE(cpudata->perf); - policy->cpuinfo.min_freq = min_freq; - policy->cpuinfo.max_freq = max_freq; + policy->cpuinfo.min_freq = policy->min = perf_to_freq(perf, + cpudata->nominal_freq, + perf.lowest_perf); + policy->cpuinfo.max_freq = policy->max = perf_to_freq(perf, + cpudata->nominal_freq, + perf.highest_perf); + + ret = amd_pstate_cppc_enable(policy); + if (ret) + goto free_cpudata1; policy->boost_enabled = READ_ONCE(cpudata->boost_supported); @@ -1029,9 +1008,6 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) goto free_cpudata2; } - cpudata->max_limit_freq = max_freq; - cpudata->min_limit_freq = min_freq; - policy->driver_data = cpudata; if (!current_pstate_driver->adjust_perf) @@ -1042,6 +1018,7 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) free_cpudata2: freq_qos_remove_request(&cpudata->req[0]); free_cpudata1: + pr_warn("Failed to initialize CPU %d: %d\n", policy->cpu, ret); kfree(cpudata); return ret; } @@ -1056,28 +1033,6 @@ static void amd_pstate_cpu_exit(struct cpufreq_policy *policy) kfree(cpudata); } -static int amd_pstate_cpu_resume(struct cpufreq_policy *policy) -{ - int ret; - - ret = amd_pstate_cppc_enable(true); - if (ret) - pr_err("failed to enable amd-pstate during resume, return %d\n", ret); - - return ret; -} - -static int amd_pstate_cpu_suspend(struct cpufreq_policy *policy) -{ - int ret; - - ret = amd_pstate_cppc_enable(false); - if (ret) - pr_err("failed to disable amd-pstate during suspend, return %d\n", ret); - - return ret; -} - /* Sysfs attributes */ /* @@ -1088,27 +1043,27 @@ static int amd_pstate_cpu_suspend(struct cpufreq_policy *policy) static ssize_t show_amd_pstate_max_freq(struct cpufreq_policy *policy, char *buf) { - int max_freq; - struct amd_cpudata *cpudata = policy->driver_data; + struct amd_cpudata *cpudata; + union perf_cached perf; - max_freq = READ_ONCE(cpudata->max_freq); - if (max_freq < 0) - return max_freq; + cpudata = policy->driver_data; + perf = READ_ONCE(cpudata->perf); - return sysfs_emit(buf, "%u\n", max_freq); + return sysfs_emit(buf, "%u\n", + perf_to_freq(perf, cpudata->nominal_freq, perf.highest_perf)); } static ssize_t show_amd_pstate_lowest_nonlinear_freq(struct cpufreq_policy *policy, char *buf) { - int freq; - struct amd_cpudata *cpudata = policy->driver_data; + struct amd_cpudata *cpudata; + union perf_cached perf; - freq = READ_ONCE(cpudata->lowest_nonlinear_freq); - if (freq < 0) - return freq; + cpudata = policy->driver_data; + perf = READ_ONCE(cpudata->perf); - return sysfs_emit(buf, "%u\n", freq); + return sysfs_emit(buf, "%u\n", + perf_to_freq(perf, cpudata->nominal_freq, perf.lowest_nonlinear_perf)); } /* @@ -1118,12 +1073,11 @@ static ssize_t show_amd_pstate_lowest_nonlinear_freq(struct cpufreq_policy *poli static ssize_t show_amd_pstate_highest_perf(struct cpufreq_policy *policy, char *buf) { - u8 perf; - struct amd_cpudata *cpudata = policy->driver_data; + struct amd_cpudata *cpudata; - perf = READ_ONCE(cpudata->highest_perf); + cpudata = policy->driver_data; - return sysfs_emit(buf, "%u\n", perf); + return sysfs_emit(buf, "%u\n", cpudata->perf.highest_perf); } static ssize_t show_amd_pstate_prefcore_ranking(struct cpufreq_policy *policy, @@ -1170,8 +1124,10 @@ static ssize_t show_energy_performance_available_preferences( static ssize_t store_energy_performance_preference( struct cpufreq_policy *policy, const char *buf, size_t count) { + struct amd_cpudata *cpudata = policy->driver_data; char str_preference[21]; ssize_t ret; + u8 epp; ret = sscanf(buf, "%20s", str_preference); if (ret != 1) @@ -1181,9 +1137,17 @@ static ssize_t store_energy_performance_preference( if (ret < 0) return -EINVAL; - guard(mutex)(&amd_pstate_limits_lock); + if (!ret) + epp = cpudata->epp_default; + else + epp = epp_values[ret]; + + if (epp > 0 && policy->policy == CPUFREQ_POLICY_PERFORMANCE) { + pr_debug("EPP cannot be set under performance policy\n"); + return -EBUSY; + } - ret = amd_pstate_set_energy_pref_index(policy, ret); + ret = amd_pstate_set_epp(policy, epp); return ret ? ret : count; } @@ -1192,9 +1156,11 @@ static ssize_t show_energy_performance_preference( struct cpufreq_policy *policy, char *buf) { struct amd_cpudata *cpudata = policy->driver_data; - u8 preference; + u8 preference, epp; - switch (cpudata->epp_cached) { + epp = FIELD_GET(AMD_CPPC_EPP_PERF_MASK, cpudata->cppc_req_cached); + + switch (epp) { case AMD_CPPC_EPP_PERFORMANCE: preference = EPP_INDEX_PERFORMANCE; break; @@ -1216,7 +1182,6 @@ static ssize_t show_energy_performance_preference( static void amd_pstate_driver_cleanup(void) { - amd_pstate_cppc_enable(false); cppc_state = AMD_PSTATE_DISABLE; current_pstate_driver = NULL; } @@ -1250,14 +1215,6 @@ static int amd_pstate_register_driver(int mode) cppc_state = mode; - ret = amd_pstate_cppc_enable(true); - if (ret) { - pr_err("failed to enable cppc during amd-pstate driver registration, return %d\n", - ret); - amd_pstate_driver_cleanup(); - return ret; - } - /* at least one CPU supports CPB */ current_pstate_driver->boost_enabled = cpu_feature_enabled(X86_FEATURE_CPB); @@ -1355,8 +1312,10 @@ int amd_pstate_update_status(const char *buf, size_t size) if (mode_idx < 0 || mode_idx >= AMD_PSTATE_MAX) return -EINVAL; - if (mode_state_machine[cppc_state][mode_idx]) + if (mode_state_machine[cppc_state][mode_idx]) { + guard(mutex)(&amd_pstate_driver_lock); return mode_state_machine[cppc_state][mode_idx](mode_idx); + } return 0; } @@ -1377,7 +1336,6 @@ static ssize_t status_store(struct device *a, struct device_attribute *b, char *p = memchr(buf, '\n', count); int ret; - guard(mutex)(&amd_pstate_driver_lock); ret = amd_pstate_update_status(buf, p ? p - buf : count); return ret < 0 ? ret : count; @@ -1453,10 +1411,11 @@ static bool amd_pstate_acpi_pm_profile_undefined(void) static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) { - int min_freq, max_freq, ret; struct amd_cpudata *cpudata; + union perf_cached perf; struct device *dev; u64 value; + int ret; /* * Resetting PERF_CTL_MSR will put the CPU in P0 frequency, @@ -1477,7 +1436,7 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) if (ret) goto free_cpudata1; - amd_pstate_init_prefcore(cpudata); + amd_pstate_init_asym_prio(cpudata); ret = amd_pstate_init_freq(cpudata); if (ret) @@ -1487,18 +1446,23 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) if (ret) goto free_cpudata1; - min_freq = READ_ONCE(cpudata->min_freq); - max_freq = READ_ONCE(cpudata->max_freq); + perf = READ_ONCE(cpudata->perf); + + policy->cpuinfo.min_freq = policy->min = perf_to_freq(perf, + cpudata->nominal_freq, + perf.lowest_perf); + policy->cpuinfo.max_freq = policy->max = perf_to_freq(perf, + cpudata->nominal_freq, + perf.highest_perf); + policy->driver_data = cpudata; + + ret = amd_pstate_cppc_enable(policy); + if (ret) + goto free_cpudata1; - policy->cpuinfo.min_freq = min_freq; - policy->cpuinfo.max_freq = max_freq; /* It will be updated by governor */ policy->cur = policy->cpuinfo.min_freq; - policy->driver_data = cpudata; - - policy->min = policy->cpuinfo.min_freq; - policy->max = policy->cpuinfo.max_freq; policy->boost_enabled = READ_ONCE(cpudata->boost_supported); @@ -1520,13 +1484,8 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) if (ret) return ret; WRITE_ONCE(cpudata->cppc_req_cached, value); - - ret = rdmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_CAP1, &value); - if (ret) - return ret; - WRITE_ONCE(cpudata->cppc_cap1_cached, value); } - ret = amd_pstate_set_epp(cpudata, cpudata->epp_default); + ret = amd_pstate_set_epp(policy, cpudata->epp_default); if (ret) return ret; @@ -1535,6 +1494,7 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) return 0; free_cpudata1: + pr_warn("Failed to initialize CPU %d: %d\n", policy->cpu, ret); kfree(cpudata); return ret; } @@ -1554,24 +1514,21 @@ static void amd_pstate_epp_cpu_exit(struct cpufreq_policy *policy) static int amd_pstate_epp_update_limit(struct cpufreq_policy *policy) { struct amd_cpudata *cpudata = policy->driver_data; + union perf_cached perf; u8 epp; - amd_pstate_update_min_max_limit(policy); + if (policy->min != cpudata->min_limit_freq || policy->max != cpudata->max_limit_freq) + amd_pstate_update_min_max_limit(policy); if (cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) epp = 0; else - epp = READ_ONCE(cpudata->epp_cached); + epp = FIELD_GET(AMD_CPPC_EPP_PERF_MASK, cpudata->cppc_req_cached); - if (trace_amd_pstate_epp_perf_enabled()) { - trace_amd_pstate_epp_perf(cpudata->cpu, cpudata->highest_perf, epp, - cpudata->min_limit_perf, - cpudata->max_limit_perf, - policy->boost_enabled); - } + perf = READ_ONCE(cpudata->perf); - return amd_pstate_update_perf(cpudata, cpudata->min_limit_perf, 0U, - cpudata->max_limit_perf, epp, false); + return amd_pstate_update_perf(policy, perf.min_limit_perf, 0U, + perf.max_limit_perf, epp, false); } static int amd_pstate_epp_set_policy(struct cpufreq_policy *policy) @@ -1582,9 +1539,6 @@ static int amd_pstate_epp_set_policy(struct cpufreq_policy *policy) if (!policy->cpuinfo.max_freq) return -ENODEV; - pr_debug("set_policy: cpuinfo.max %u policy->max %u\n", - policy->cpuinfo.max_freq, policy->max); - cpudata->policy = policy->policy; ret = amd_pstate_epp_update_limit(policy); @@ -1600,73 +1554,21 @@ static int amd_pstate_epp_set_policy(struct cpufreq_policy *policy) return 0; } -static int amd_pstate_epp_reenable(struct cpufreq_policy *policy) -{ - struct amd_cpudata *cpudata = policy->driver_data; - u8 max_perf; - int ret; - - ret = amd_pstate_cppc_enable(true); - if (ret) - pr_err("failed to enable amd pstate during resume, return %d\n", ret); - - max_perf = READ_ONCE(cpudata->highest_perf); - - if (trace_amd_pstate_epp_perf_enabled()) { - trace_amd_pstate_epp_perf(cpudata->cpu, cpudata->highest_perf, - cpudata->epp_cached, - FIELD_GET(AMD_CPPC_MIN_PERF_MASK, cpudata->cppc_req_cached), - max_perf, policy->boost_enabled); - } - - return amd_pstate_epp_update_limit(policy); -} - static int amd_pstate_epp_cpu_online(struct cpufreq_policy *policy) { - struct amd_cpudata *cpudata = policy->driver_data; - int ret; + pr_debug("AMD CPU Core %d going online\n", policy->cpu); - pr_debug("AMD CPU Core %d going online\n", cpudata->cpu); - - ret = amd_pstate_epp_reenable(policy); - if (ret) - return ret; - cpudata->suspended = false; - - return 0; + return amd_pstate_cppc_enable(policy); } static int amd_pstate_epp_cpu_offline(struct cpufreq_policy *policy) { - struct amd_cpudata *cpudata = policy->driver_data; - u8 min_perf; - - if (cpudata->suspended) - return 0; - - min_perf = READ_ONCE(cpudata->lowest_perf); - - guard(mutex)(&amd_pstate_limits_lock); - - if (trace_amd_pstate_epp_perf_enabled()) { - trace_amd_pstate_epp_perf(cpudata->cpu, cpudata->highest_perf, - AMD_CPPC_EPP_BALANCE_POWERSAVE, - min_perf, min_perf, policy->boost_enabled); - } - - return amd_pstate_update_perf(cpudata, min_perf, 0, min_perf, - AMD_CPPC_EPP_BALANCE_POWERSAVE, false); + return 0; } static int amd_pstate_epp_suspend(struct cpufreq_policy *policy) { struct amd_cpudata *cpudata = policy->driver_data; - int ret; - - /* avoid suspending when EPP is not enabled */ - if (cppc_state != AMD_PSTATE_ACTIVE) - return 0; /* invalidate to ensure it's rewritten during resume */ cpudata->cppc_req_cached = 0; @@ -1674,11 +1576,6 @@ static int amd_pstate_epp_suspend(struct cpufreq_policy *policy) /* set this flag to avoid setting core offline*/ cpudata->suspended = true; - /* disable CPPC in lowlevel firmware */ - ret = amd_pstate_cppc_enable(false); - if (ret) - pr_err("failed to suspend, return %d\n", ret); - return 0; } @@ -1687,10 +1584,12 @@ static int amd_pstate_epp_resume(struct cpufreq_policy *policy) struct amd_cpudata *cpudata = policy->driver_data; if (cpudata->suspended) { - guard(mutex)(&amd_pstate_limits_lock); + int ret; /* enable amd pstate from suspend state*/ - amd_pstate_epp_reenable(policy); + ret = amd_pstate_epp_update_limit(policy); + if (ret) + return ret; cpudata->suspended = false; } @@ -1705,8 +1604,6 @@ static struct cpufreq_driver amd_pstate_driver = { .fast_switch = amd_pstate_fast_switch, .init = amd_pstate_cpu_init, .exit = amd_pstate_cpu_exit, - .suspend = amd_pstate_cpu_suspend, - .resume = amd_pstate_cpu_resume, .set_boost = amd_pstate_set_boost, .update_limits = amd_pstate_update_limits, .name = "amd-pstate", @@ -1869,11 +1766,14 @@ static int __init amd_pstate_init(void) } } + /* Enable ITMT support once all CPUs have initialized their asym priorities. */ + if (amd_pstate_prefcore) + sched_set_itmt_support(); + return ret; global_attr_free: cpufreq_unregister_driver(current_pstate_driver); - amd_pstate_cppc_enable(false); return ret; } device_initcall(amd_pstate_init); diff --git a/drivers/cpufreq/amd-pstate.h b/drivers/cpufreq/amd-pstate.h index 19d405c6d805..fbe1c08d3f06 100644 --- a/drivers/cpufreq/amd-pstate.h +++ b/drivers/cpufreq/amd-pstate.h @@ -13,6 +13,36 @@ /********************************************************************* * AMD P-state INTERFACE * *********************************************************************/ + +/** + * union perf_cached - A union to cache performance-related data. + * @highest_perf: the maximum performance an individual processor may reach, + * assuming ideal conditions + * For platforms that support the preferred core feature, the highest_perf value maybe + * configured to any value in the range 166-255 by the firmware (because the preferred + * core ranking is encoded in the highest_perf value). To maintain consistency across + * all platforms, we split the highest_perf and preferred core ranking values into + * cpudata->perf.highest_perf and cpudata->prefcore_ranking. + * @nominal_perf: the maximum sustained performance level of the processor, + * assuming ideal operating conditions + * @lowest_nonlinear_perf: the lowest performance level at which nonlinear power + * savings are achieved + * @lowest_perf: the absolute lowest performance level of the processor + * @min_limit_perf: Cached value of the performance corresponding to policy->min + * @max_limit_perf: Cached value of the performance corresponding to policy->max + */ +union perf_cached { + struct { + u8 highest_perf; + u8 nominal_perf; + u8 lowest_nonlinear_perf; + u8 lowest_perf; + u8 min_limit_perf; + u8 max_limit_perf; + }; + u64 val; +}; + /** * struct amd_aperf_mperf * @aperf: actual performance frequency clock count @@ -30,24 +60,11 @@ struct amd_aperf_mperf { * @cpu: CPU number * @req: constraint request to apply * @cppc_req_cached: cached performance request hints - * @highest_perf: the maximum performance an individual processor may reach, - * assuming ideal conditions - * For platforms that do not support the preferred core feature, the - * highest_pef may be configured with 166 or 255, to avoid max frequency - * calculated wrongly. we take the fixed value as the highest_perf. - * @nominal_perf: the maximum sustained performance level of the processor, - * assuming ideal operating conditions - * @lowest_nonlinear_perf: the lowest performance level at which nonlinear power - * savings are achieved - * @lowest_perf: the absolute lowest performance level of the processor + * @perf: cached performance-related data * @prefcore_ranking: the preferred core ranking, the higher value indicates a higher * priority. - * @min_limit_perf: Cached value of the performance corresponding to policy->min - * @max_limit_perf: Cached value of the performance corresponding to policy->max * @min_limit_freq: Cached value of policy->min (in khz) * @max_limit_freq: Cached value of policy->max (in khz) - * @max_freq: the frequency (in khz) that mapped to highest_perf - * @min_freq: the frequency (in khz) that mapped to lowest_perf * @nominal_freq: the frequency (in khz) that mapped to nominal_perf * @lowest_nonlinear_freq: the frequency (in khz) that mapped to lowest_nonlinear_perf * @cur: Difference of Aperf/Mperf/tsc count between last and current sample @@ -59,7 +76,6 @@ struct amd_aperf_mperf { * AMD P-State driver supports preferred core featue. * @epp_cached: Cached CPPC energy-performance preference value * @policy: Cpufreq policy value - * @cppc_cap1_cached Cached MSR_AMD_CPPC_CAP1 register value * * The amd_cpudata is key private data for each CPU thread in AMD P-State, and * represents all the attributes and goals that AMD P-State requests at runtime. @@ -70,18 +86,11 @@ struct amd_cpudata { struct freq_qos_request req[2]; u64 cppc_req_cached; - u8 highest_perf; - u8 nominal_perf; - u8 lowest_nonlinear_perf; - u8 lowest_perf; - u8 prefcore_ranking; - u8 min_limit_perf; - u8 max_limit_perf; - u32 min_limit_freq; - u32 max_limit_freq; + union perf_cached perf; - u32 max_freq; - u32 min_freq; + u8 prefcore_ranking; + u32 min_limit_freq; + u32 max_limit_freq; u32 nominal_freq; u32 lowest_nonlinear_freq; @@ -93,9 +102,7 @@ struct amd_cpudata { bool hw_prefcore; /* EPP feature related attributes*/ - u8 epp_cached; u32 policy; - u64 cppc_cap1_cached; bool suspended; u8 epp_default; }; diff --git a/include/linux/cpufreq.h b/include/linux/cpufreq.h index 7fe0981a7e46..dde5212d256c 100644 --- a/include/linux/cpufreq.h +++ b/include/linux/cpufreq.h @@ -210,6 +210,9 @@ static inline struct cpufreq_policy *cpufreq_cpu_get(unsigned int cpu) static inline void cpufreq_cpu_put(struct cpufreq_policy *policy) { } #endif +/* Scope based cleanup macro for cpufreq_policy kobject reference counting */ +DEFINE_FREE(put_cpufreq_policy, struct cpufreq_policy *, if (_T) cpufreq_cpu_put(_T)) + static inline bool policy_is_inactive(struct cpufreq_policy *policy) { return cpumask_empty(policy->cpus); diff --git a/include/linux/sched/topology.h b/include/linux/sched/topology.h index 7f3dbafe1817..85a772968af3 100644 --- a/include/linux/sched/topology.h +++ b/include/linux/sched/topology.h @@ -203,6 +203,8 @@ struct sched_domain_topology_level { }; extern void __init set_sched_topology(struct sched_domain_topology_level *tl); +extern void sched_update_asym_prefer_cpu(int cpu, int old_prio, int new_prio); + # define SD_INIT_NAME(type) .name = #type @@ -237,6 +239,10 @@ static inline bool cpus_share_resources(int this_cpu, int that_cpu) return true; } +static inline void sched_update_asym_prefer_cpu(int cpu, int old_prio, int new_prio) +{ +} + #endif /* !CONFIG_SMP */ #if defined(CONFIG_ENERGY_MODEL) && defined(CONFIG_CPU_FREQ_GOV_SCHEDUTIL) diff --git a/kernel/sched/debug.c b/kernel/sched/debug.c index a0893a483d35..228d981dea5d 100644 --- a/kernel/sched/debug.c +++ b/kernel/sched/debug.c @@ -586,6 +586,10 @@ static void register_sd(struct sched_domain *sd, struct dentry *parent) debugfs_create_file("flags", 0444, parent, &sd->flags, &sd_flags_fops); debugfs_create_file("groups_flags", 0444, parent, &sd->groups->flags, &sd_flags_fops); debugfs_create_u32("level", 0444, parent, (u32 *)&sd->level); + + if (sd->flags & SD_ASYM_PACKING) + debugfs_create_u32("group_asym_prefer_cpu", 0444, parent, + (u32 *)&sd->groups->asym_prefer_cpu); } void update_sched_domain_debugfs(void) diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c index 89c7260103e1..e197b0edcfb2 100644 --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -10261,7 +10261,7 @@ sched_group_asym(struct lb_env *env, struct sg_lb_stats *sgs, struct sched_group (sgs->group_weight - sgs->idle_cpus != 1)) return false; - return sched_asym(env->sd, env->dst_cpu, group->asym_prefer_cpu); + return sched_asym(env->sd, env->dst_cpu, READ_ONCE(group->asym_prefer_cpu)); } /* One group has more than one SMT CPU while the other group does not */ @@ -10498,7 +10498,8 @@ static bool update_sd_pick_busiest(struct lb_env *env, case group_asym_packing: /* Prefer to move from lowest priority CPU's work */ - return sched_asym_prefer(sds->busiest->asym_prefer_cpu, sg->asym_prefer_cpu); + return sched_asym_prefer(READ_ONCE(sds->busiest->asym_prefer_cpu), + READ_ONCE(sg->asym_prefer_cpu)); case group_misfit_task: /* diff --git a/kernel/sched/topology.c b/kernel/sched/topology.c index 363ad268a25b..879c1d6df5d8 100644 --- a/kernel/sched/topology.c +++ b/kernel/sched/topology.c @@ -1344,6 +1344,64 @@ static void init_sched_groups_capacity(int cpu, struct sched_domain *sd) update_group_capacity(sd, cpu); } +#ifdef CONFIG_SMP + +/* Update the "asym_prefer_cpu" when arch_asym_cpu_priority() changes. */ +void sched_update_asym_prefer_cpu(int cpu, int old_prio, int new_prio) +{ + int asym_prefer_cpu = cpu; + struct sched_domain *sd; + + guard(rcu)(); + + for_each_domain(cpu, sd) { + struct sched_group *sg; + int group_cpu; + + if (!(sd->flags & SD_ASYM_PACKING)) + continue; + + /* + * Groups of overlapping domain are replicated per NUMA + * node and will require updating "asym_prefer_cpu" on + * each local copy. + * + * If you are hitting this warning, consider moving + * "sg->asym_prefer_cpu" to "sg->sgc->asym_prefer_cpu" + * which is shared by all the overlapping groups. + */ + WARN_ON_ONCE(sd->flags & SD_OVERLAP); + + sg = sd->groups; + if (cpu != sg->asym_prefer_cpu) { + /* + * Since the parent is a superset of the current group, + * if the cpu is not the "asym_prefer_cpu" at the + * current level, it cannot be the preferred CPU at a + * higher levels either. + */ + if (!sched_asym_prefer(cpu, sg->asym_prefer_cpu)) + return; + + WRITE_ONCE(sg->asym_prefer_cpu, cpu); + continue; + } + + /* Ranking has improved; CPU is still the preferred one. */ + if (new_prio >= old_prio) + continue; + + for_each_cpu(group_cpu, sched_group_span(sg)) { + if (sched_asym_prefer(group_cpu, asym_prefer_cpu)) + asym_prefer_cpu = group_cpu; + } + + WRITE_ONCE(sg->asym_prefer_cpu, asym_prefer_cpu); + } +} + +#endif /* CONFIG_SMP */ + /* * Set of available CPUs grouped by their corresponding capacities * Each list entry contains a CPU mask reflecting CPUs that share the same -- 2.49.0.391.g4bbb303af6 From 4acb0ffcd8f95102a62542dbf9d83f57c89c2090 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:31:45 +0200 Subject: [PATCH 2/9] amd-tlb-broadcast Signed-off-by: Peter Jung --- arch/x86/Kconfig | 2 +- arch/x86/Kconfig.cpu | 4 + arch/x86/hyperv/mmu.c | 1 - arch/x86/include/asm/cpufeatures.h | 1 + arch/x86/include/asm/disabled-features.h | 8 +- arch/x86/include/asm/mmu.h | 12 + arch/x86/include/asm/mmu_context.h | 10 +- arch/x86/include/asm/msr-index.h | 2 + arch/x86/include/asm/paravirt.h | 5 - arch/x86/include/asm/paravirt_types.h | 2 - arch/x86/include/asm/tlb.h | 138 ++++++++ arch/x86/include/asm/tlbbatch.h | 5 + arch/x86/include/asm/tlbflush.h | 70 ++++ arch/x86/kernel/cpu/amd.c | 10 + arch/x86/kernel/kvm.c | 1 - arch/x86/kernel/paravirt.c | 16 - arch/x86/mm/pgtable.c | 27 +- arch/x86/mm/tlb.c | 427 +++++++++++++++++++++-- arch/x86/xen/mmu_pv.c | 1 - tools/arch/x86/include/asm/msr-index.h | 2 + 20 files changed, 655 insertions(+), 89 deletions(-) diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index aeb95b6e5536..088f7555e1ac 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -277,7 +277,7 @@ config X86 select HAVE_PCI select HAVE_PERF_REGS select HAVE_PERF_USER_STACK_DUMP - select MMU_GATHER_RCU_TABLE_FREE if PARAVIRT + select MMU_GATHER_RCU_TABLE_FREE select MMU_GATHER_MERGE_VMAS select HAVE_POSIX_CPU_TIMERS_TASK_WORK select HAVE_REGS_AND_STACK_ACCESS_API diff --git a/arch/x86/Kconfig.cpu b/arch/x86/Kconfig.cpu index 42e6a40876ea..9bade26f6c67 100644 --- a/arch/x86/Kconfig.cpu +++ b/arch/x86/Kconfig.cpu @@ -401,6 +401,10 @@ menuconfig PROCESSOR_SELECT This lets you choose what x86 vendor support code your kernel will include. +config BROADCAST_TLB_FLUSH + def_bool y + depends on CPU_SUP_AMD && 64BIT + config CPU_SUP_INTEL default y bool "Support Intel processors" if PROCESSOR_SELECT diff --git a/arch/x86/hyperv/mmu.c b/arch/x86/hyperv/mmu.c index cc8c3bd0e7c2..1f7c3082a36d 100644 --- a/arch/x86/hyperv/mmu.c +++ b/arch/x86/hyperv/mmu.c @@ -239,5 +239,4 @@ void hyperv_setup_mmu_ops(void) pr_info("Using hypercall for remote TLB flush\n"); pv_ops.mmu.flush_tlb_multi = hyperv_flush_tlb_multi; - pv_ops.mmu.tlb_remove_table = tlb_remove_table; } diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h index 508c0dad116b..8770dc185fe9 100644 --- a/arch/x86/include/asm/cpufeatures.h +++ b/arch/x86/include/asm/cpufeatures.h @@ -338,6 +338,7 @@ #define X86_FEATURE_CLZERO (13*32+ 0) /* "clzero" CLZERO instruction */ #define X86_FEATURE_IRPERF (13*32+ 1) /* "irperf" Instructions Retired Count */ #define X86_FEATURE_XSAVEERPTR (13*32+ 2) /* "xsaveerptr" Always save/restore FP error pointers */ +#define X86_FEATURE_INVLPGB (13*32+ 3) /* INVLPGB and TLBSYNC instructions supported */ #define X86_FEATURE_RDPRU (13*32+ 4) /* "rdpru" Read processor register at user level */ #define X86_FEATURE_WBNOINVD (13*32+ 9) /* "wbnoinvd" WBNOINVD instruction */ #define X86_FEATURE_AMD_IBPB (13*32+12) /* Indirect Branch Prediction Barrier */ diff --git a/arch/x86/include/asm/disabled-features.h b/arch/x86/include/asm/disabled-features.h index c492bdc97b05..be8c38855068 100644 --- a/arch/x86/include/asm/disabled-features.h +++ b/arch/x86/include/asm/disabled-features.h @@ -129,6 +129,12 @@ #define DISABLE_SEV_SNP (1 << (X86_FEATURE_SEV_SNP & 31)) #endif +#ifdef CONFIG_BROADCAST_TLB_FLUSH +#define DISABLE_INVLPGB 0 +#else +#define DISABLE_INVLPGB (1 << (X86_FEATURE_INVLPGB & 31)) +#endif + /* * Make sure to add features to the correct mask */ @@ -146,7 +152,7 @@ #define DISABLED_MASK11 (DISABLE_RETPOLINE|DISABLE_RETHUNK|DISABLE_UNRET| \ DISABLE_CALL_DEPTH_TRACKING|DISABLE_USER_SHSTK) #define DISABLED_MASK12 (DISABLE_FRED|DISABLE_LAM) -#define DISABLED_MASK13 0 +#define DISABLED_MASK13 (DISABLE_INVLPGB) #define DISABLED_MASK14 0 #define DISABLED_MASK15 0 #define DISABLED_MASK16 (DISABLE_PKU|DISABLE_OSPKE|DISABLE_LA57|DISABLE_UMIP| \ diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h index 3b496cdcb74b..8b8055a8eb9e 100644 --- a/arch/x86/include/asm/mmu.h +++ b/arch/x86/include/asm/mmu.h @@ -69,6 +69,18 @@ typedef struct { u16 pkey_allocation_map; s16 execute_only_pkey; #endif + +#ifdef CONFIG_BROADCAST_TLB_FLUSH + /* + * The global ASID will be a non-zero value when the process has + * the same ASID across all CPUs, allowing it to make use of + * hardware-assisted remote TLB invalidation like AMD INVLPGB. + */ + u16 global_asid; + + /* The process is transitioning to a new global ASID number. */ + bool asid_transition; +#endif } mm_context_t; #define INIT_MM_CONTEXT(mm) \ diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h index 795fdd53bd0a..2398058b6e83 100644 --- a/arch/x86/include/asm/mmu_context.h +++ b/arch/x86/include/asm/mmu_context.h @@ -2,7 +2,6 @@ #ifndef _ASM_X86_MMU_CONTEXT_H #define _ASM_X86_MMU_CONTEXT_H -#include #include #include #include @@ -13,6 +12,7 @@ #include #include #include +#include extern atomic64_t last_mm_ctx_id; @@ -139,6 +139,11 @@ static inline void mm_reset_untag_mask(struct mm_struct *mm) #define enter_lazy_tlb enter_lazy_tlb extern void enter_lazy_tlb(struct mm_struct *mm, struct task_struct *tsk); +#define mm_init_global_asid mm_init_global_asid +extern void mm_init_global_asid(struct mm_struct *mm); + +extern void mm_free_global_asid(struct mm_struct *mm); + /* * Init a new mm. Used on mm copies, like at fork() * and on mm's that are brand-new, like at execve(). @@ -161,6 +166,8 @@ static inline int init_new_context(struct task_struct *tsk, mm->context.execute_only_pkey = -1; } #endif + + mm_init_global_asid(mm); mm_reset_untag_mask(mm); init_new_context_ldt(mm); return 0; @@ -170,6 +177,7 @@ static inline int init_new_context(struct task_struct *tsk, static inline void destroy_context(struct mm_struct *mm) { destroy_context_ldt(mm); + mm_free_global_asid(mm); } extern void switch_mm(struct mm_struct *prev, struct mm_struct *next, diff --git a/arch/x86/include/asm/msr-index.h b/arch/x86/include/asm/msr-index.h index fc2634cc48fd..22892a37c849 100644 --- a/arch/x86/include/asm/msr-index.h +++ b/arch/x86/include/asm/msr-index.h @@ -25,6 +25,7 @@ #define _EFER_SVME 12 /* Enable virtualization */ #define _EFER_LMSLE 13 /* Long Mode Segment Limit Enable */ #define _EFER_FFXSR 14 /* Enable Fast FXSAVE/FXRSTOR */ +#define _EFER_TCE 15 /* Enable Translation Cache Extensions */ #define _EFER_AUTOIBRS 21 /* Enable Automatic IBRS */ #define EFER_SCE (1<<_EFER_SCE) @@ -34,6 +35,7 @@ #define EFER_SVME (1<<_EFER_SVME) #define EFER_LMSLE (1<<_EFER_LMSLE) #define EFER_FFXSR (1<<_EFER_FFXSR) +#define EFER_TCE (1<<_EFER_TCE) #define EFER_AUTOIBRS (1<<_EFER_AUTOIBRS) /* diff --git a/arch/x86/include/asm/paravirt.h b/arch/x86/include/asm/paravirt.h index 29e7331a0c98..0c11641e6547 100644 --- a/arch/x86/include/asm/paravirt.h +++ b/arch/x86/include/asm/paravirt.h @@ -91,11 +91,6 @@ static inline void __flush_tlb_multi(const struct cpumask *cpumask, PVOP_VCALL2(mmu.flush_tlb_multi, cpumask, info); } -static inline void paravirt_tlb_remove_table(struct mmu_gather *tlb, void *table) -{ - PVOP_VCALL2(mmu.tlb_remove_table, tlb, table); -} - static inline void paravirt_arch_exit_mmap(struct mm_struct *mm) { PVOP_VCALL1(mmu.exit_mmap, mm); diff --git a/arch/x86/include/asm/paravirt_types.h b/arch/x86/include/asm/paravirt_types.h index abccfccc2e3f..42990293fd5f 100644 --- a/arch/x86/include/asm/paravirt_types.h +++ b/arch/x86/include/asm/paravirt_types.h @@ -133,8 +133,6 @@ struct pv_mmu_ops { void (*flush_tlb_multi)(const struct cpumask *cpus, const struct flush_tlb_info *info); - void (*tlb_remove_table)(struct mmu_gather *tlb, void *table); - /* Hook for intercepting the destruction of an mm_struct. */ void (*exit_mmap)(struct mm_struct *mm); void (*notify_page_enc_status_changed)(unsigned long pfn, int npages, bool enc); diff --git a/arch/x86/include/asm/tlb.h b/arch/x86/include/asm/tlb.h index 77f52bc1578a..866ea78ba156 100644 --- a/arch/x86/include/asm/tlb.h +++ b/arch/x86/include/asm/tlb.h @@ -6,6 +6,9 @@ static inline void tlb_flush(struct mmu_gather *tlb); #include +#include +#include +#include static inline void tlb_flush(struct mmu_gather *tlb) { @@ -25,4 +28,139 @@ static inline void invlpg(unsigned long addr) asm volatile("invlpg (%0)" ::"r" (addr) : "memory"); } +enum addr_stride { + PTE_STRIDE = 0, + PMD_STRIDE = 1 +}; + +/* + * INVLPGB can be targeted by virtual address, PCID, ASID, or any combination + * of the three. For example: + * - FLAG_VA | FLAG_INCLUDE_GLOBAL: invalidate all TLB entries at the address + * - FLAG_PCID: invalidate all TLB entries matching the PCID + * + * The first is used to invalidate (kernel) mappings at a particular + * address across all processes. + * + * The latter invalidates all TLB entries matching a PCID. + */ +#define INVLPGB_FLAG_VA BIT(0) +#define INVLPGB_FLAG_PCID BIT(1) +#define INVLPGB_FLAG_ASID BIT(2) +#define INVLPGB_FLAG_INCLUDE_GLOBAL BIT(3) +#define INVLPGB_FLAG_FINAL_ONLY BIT(4) +#define INVLPGB_FLAG_INCLUDE_NESTED BIT(5) + +/* The implied mode when all bits are clear: */ +#define INVLPGB_MODE_ALL_NONGLOBALS 0UL + +#ifdef CONFIG_BROADCAST_TLB_FLUSH +/* + * INVLPGB does broadcast TLB invalidation across all the CPUs in the system. + * + * The INVLPGB instruction is weakly ordered, and a batch of invalidations can + * be done in a parallel fashion. + * + * The instruction takes the number of extra pages to invalidate, beyond the + * first page, while __invlpgb gets the more human readable number of pages to + * invalidate. + * + * The bits in rax[0:2] determine respectively which components of the address + * (VA, PCID, ASID) get compared when flushing. If neither bits are set, *any* + * address in the specified range matches. + * + * Since it is desired to only flush TLB entries for the ASID that is executing + * the instruction (a host/hypervisor or a guest), the ASID valid bit should + * always be set. On a host/hypervisor, the hardware will use the ASID value + * specified in EDX[15:0] (which should be 0). On a guest, the hardware will + * use the actual ASID value of the guest. + * + * TLBSYNC is used to ensure that pending INVLPGB invalidations initiated from + * this CPU have completed. + */ +static inline void __invlpgb(unsigned long asid, unsigned long pcid, + unsigned long addr, u16 nr_pages, + enum addr_stride stride, u8 flags) +{ + u64 rax = addr | flags | INVLPGB_FLAG_ASID; + u32 ecx = (stride << 31) | (nr_pages - 1); + u32 edx = (pcid << 16) | asid; + + /* The low bits in rax are for flags. Verify addr is clean. */ + VM_WARN_ON_ONCE(addr & ~PAGE_MASK); + + /* INVLPGB; supported in binutils >= 2.36. */ + asm volatile(".byte 0x0f, 0x01, 0xfe" :: "a" (rax), "c" (ecx), "d" (edx)); +} + +static inline void __invlpgb_all(unsigned long asid, unsigned long pcid, u8 flags) +{ + __invlpgb(asid, pcid, 0, 1, 0, flags); +} + +static inline void __tlbsync(void) +{ + /* + * TLBSYNC waits for INVLPGB instructions originating on the same CPU + * to have completed. Print a warning if the task has been migrated, + * and might not be waiting on all the INVLPGBs issued during this TLB + * invalidation sequence. + */ + cant_migrate(); + + /* TLBSYNC: supported in binutils >= 0.36. */ + asm volatile(".byte 0x0f, 0x01, 0xff" ::: "memory"); +} +#else +/* Some compilers (I'm looking at you clang!) simply can't do DCE */ +static inline void __invlpgb(unsigned long asid, unsigned long pcid, + unsigned long addr, u16 nr_pages, + enum addr_stride s, u8 flags) { } +static inline void __invlpgb_all(unsigned long asid, unsigned long pcid, u8 flags) { } +static inline void __tlbsync(void) { } +#endif + +static inline void invlpgb_flush_user_nr_nosync(unsigned long pcid, + unsigned long addr, + u16 nr, bool stride) +{ + enum addr_stride str = stride ? PMD_STRIDE : PTE_STRIDE; + u8 flags = INVLPGB_FLAG_PCID | INVLPGB_FLAG_VA; + + __invlpgb(0, pcid, addr, nr, str, flags); +} + +/* Flush all mappings for a given PCID, not including globals. */ +static inline void invlpgb_flush_single_pcid_nosync(unsigned long pcid) +{ + __invlpgb_all(0, pcid, INVLPGB_FLAG_PCID); +} + +/* Flush all mappings, including globals, for all PCIDs. */ +static inline void invlpgb_flush_all(void) +{ + /* + * TLBSYNC at the end needs to make sure all flushes done on the + * current CPU have been executed system-wide. Therefore, make + * sure nothing gets migrated in-between but disable preemption + * as it is cheaper. + */ + guard(preempt)(); + __invlpgb_all(0, 0, INVLPGB_FLAG_INCLUDE_GLOBAL); + __tlbsync(); +} + +/* Flush addr, including globals, for all PCIDs. */ +static inline void invlpgb_flush_addr_nosync(unsigned long addr, u16 nr) +{ + __invlpgb(0, 0, addr, nr, PTE_STRIDE, INVLPGB_FLAG_INCLUDE_GLOBAL); +} + +/* Flush all mappings for all PCIDs except globals. */ +static inline void invlpgb_flush_all_nonglobals(void) +{ + guard(preempt)(); + __invlpgb_all(0, 0, INVLPGB_MODE_ALL_NONGLOBALS); + __tlbsync(); +} #endif /* _ASM_X86_TLB_H */ diff --git a/arch/x86/include/asm/tlbbatch.h b/arch/x86/include/asm/tlbbatch.h index 1ad56eb3e8a8..80aaf64ff25f 100644 --- a/arch/x86/include/asm/tlbbatch.h +++ b/arch/x86/include/asm/tlbbatch.h @@ -10,6 +10,11 @@ struct arch_tlbflush_unmap_batch { * the PFNs being flushed.. */ struct cpumask cpumask; + /* + * Set if pages were unmapped from any MM, even one that does not + * have active CPUs in its cpumask. + */ + bool unmapped_pages; }; #endif /* _ARCH_X86_TLBBATCH_H */ diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h index 3da645139748..a9af8759de34 100644 --- a/arch/x86/include/asm/tlbflush.h +++ b/arch/x86/include/asm/tlbflush.h @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -183,6 +184,9 @@ static inline void cr4_init_shadow(void) extern unsigned long mmu_cr4_features; extern u32 *trampoline_cr4_features; +/* How many pages can be invalidated with one INVLPGB. */ +extern u16 invlpgb_count_max; + extern void initialize_tlbstate_and_flush(void); /* @@ -231,6 +235,71 @@ void flush_tlb_one_kernel(unsigned long addr); void flush_tlb_multi(const struct cpumask *cpumask, const struct flush_tlb_info *info); +static inline bool is_dyn_asid(u16 asid) +{ + return asid < TLB_NR_DYN_ASIDS; +} + +static inline bool is_global_asid(u16 asid) +{ + return !is_dyn_asid(asid); +} + +#ifdef CONFIG_BROADCAST_TLB_FLUSH +static inline u16 mm_global_asid(struct mm_struct *mm) +{ + u16 asid; + + if (!cpu_feature_enabled(X86_FEATURE_INVLPGB)) + return 0; + + asid = smp_load_acquire(&mm->context.global_asid); + + /* mm->context.global_asid is either 0, or a global ASID */ + VM_WARN_ON_ONCE(asid && is_dyn_asid(asid)); + + return asid; +} + +static inline void mm_init_global_asid(struct mm_struct *mm) +{ + if (cpu_feature_enabled(X86_FEATURE_INVLPGB)) { + mm->context.global_asid = 0; + mm->context.asid_transition = false; + } +} + +static inline void mm_assign_global_asid(struct mm_struct *mm, u16 asid) +{ + /* + * Notably flush_tlb_mm_range() -> broadcast_tlb_flush() -> + * finish_asid_transition() needs to observe asid_transition = true + * once it observes global_asid. + */ + mm->context.asid_transition = true; + smp_store_release(&mm->context.global_asid, asid); +} + +static inline void mm_clear_asid_transition(struct mm_struct *mm) +{ + WRITE_ONCE(mm->context.asid_transition, false); +} + +static inline bool mm_in_asid_transition(struct mm_struct *mm) +{ + if (!cpu_feature_enabled(X86_FEATURE_INVLPGB)) + return false; + + return mm && READ_ONCE(mm->context.asid_transition); +} +#else +static inline u16 mm_global_asid(struct mm_struct *mm) { return 0; } +static inline void mm_init_global_asid(struct mm_struct *mm) { } +static inline void mm_assign_global_asid(struct mm_struct *mm, u16 asid) { } +static inline void mm_clear_asid_transition(struct mm_struct *mm) { } +static inline bool mm_in_asid_transition(struct mm_struct *mm) { return false; } +#endif /* CONFIG_BROADCAST_TLB_FLUSH */ + #ifdef CONFIG_PARAVIRT #include #endif @@ -284,6 +353,7 @@ static inline void arch_tlbbatch_add_pending(struct arch_tlbflush_unmap_batch *b { inc_mm_tlb_gen(mm); cpumask_or(&batch->cpumask, &batch->cpumask, mm_cpumask(mm)); + batch->unmapped_pages = true; mmu_notifier_arch_invalidate_secondary_tlbs(mm, 0, -1UL); } diff --git a/arch/x86/kernel/cpu/amd.c b/arch/x86/kernel/cpu/amd.c index 4c9b20d028eb..b1146c1b35a7 100644 --- a/arch/x86/kernel/cpu/amd.c +++ b/arch/x86/kernel/cpu/amd.c @@ -29,6 +29,8 @@ #include "cpu.h" +u16 invlpgb_count_max __ro_after_init; + static inline int rdmsrl_amd_safe(unsigned msr, unsigned long long *p) { u32 gprs[8] = { 0 }; @@ -1074,6 +1076,10 @@ static void init_amd(struct cpuinfo_x86 *c) /* AMD CPUs don't need fencing after x2APIC/TSC_DEADLINE MSR writes. */ clear_cpu_cap(c, X86_FEATURE_APIC_MSRS_FENCE); + + /* Enable Translation Cache Extension */ + if (cpu_has(c, X86_FEATURE_TCE)) + msr_set_bit(MSR_EFER, _EFER_TCE); } #ifdef CONFIG_X86_32 @@ -1140,6 +1146,10 @@ static void cpu_detect_tlb_amd(struct cpuinfo_x86 *c) tlb_lli_2m[ENTRIES] = eax & mask; tlb_lli_4m[ENTRIES] = tlb_lli_2m[ENTRIES] >> 1; + + /* Max number of pages INVLPGB can invalidate in one shot */ + if (cpu_has(c, X86_FEATURE_INVLPGB)) + invlpgb_count_max = (cpuid_edx(0x80000008) & 0xffff) + 1; } static const struct cpu_dev amd_cpu_dev = { diff --git a/arch/x86/kernel/kvm.c b/arch/x86/kernel/kvm.c index 7a422a6c5983..3be9b3342c67 100644 --- a/arch/x86/kernel/kvm.c +++ b/arch/x86/kernel/kvm.c @@ -838,7 +838,6 @@ static void __init kvm_guest_init(void) #ifdef CONFIG_SMP if (pv_tlb_flush_supported()) { pv_ops.mmu.flush_tlb_multi = kvm_flush_tlb_multi; - pv_ops.mmu.tlb_remove_table = tlb_remove_table; pr_info("KVM setup pv remote TLB flush\n"); } diff --git a/arch/x86/kernel/paravirt.c b/arch/x86/kernel/paravirt.c index c5bb980b8a67..3f4864501d9e 100644 --- a/arch/x86/kernel/paravirt.c +++ b/arch/x86/kernel/paravirt.c @@ -59,21 +59,6 @@ void __init native_pv_lock_init(void) static_branch_enable(&virt_spin_lock_key); } -#ifndef CONFIG_PT_RECLAIM -static void native_tlb_remove_table(struct mmu_gather *tlb, void *table) -{ - struct ptdesc *ptdesc = (struct ptdesc *)table; - - pagetable_dtor(ptdesc); - tlb_remove_page(tlb, ptdesc_page(ptdesc)); -} -#else -static void native_tlb_remove_table(struct mmu_gather *tlb, void *table) -{ - tlb_remove_table(tlb, table); -} -#endif - struct static_key paravirt_steal_enabled; struct static_key paravirt_steal_rq_enabled; @@ -197,7 +182,6 @@ struct paravirt_patch_template pv_ops = { .mmu.flush_tlb_kernel = native_flush_tlb_global, .mmu.flush_tlb_one_user = native_flush_tlb_one_user, .mmu.flush_tlb_multi = native_flush_tlb_multi, - .mmu.tlb_remove_table = native_tlb_remove_table, .mmu.exit_mmap = paravirt_nop, .mmu.notify_page_enc_status_changed = paravirt_nop, diff --git a/arch/x86/mm/pgtable.c b/arch/x86/mm/pgtable.c index 1fef5ad32d5a..b1c1f72c1fd1 100644 --- a/arch/x86/mm/pgtable.c +++ b/arch/x86/mm/pgtable.c @@ -18,25 +18,6 @@ EXPORT_SYMBOL(physical_mask); #define PGTABLE_HIGHMEM 0 #endif -#ifndef CONFIG_PARAVIRT -#ifndef CONFIG_PT_RECLAIM -static inline -void paravirt_tlb_remove_table(struct mmu_gather *tlb, void *table) -{ - struct ptdesc *ptdesc = (struct ptdesc *)table; - - pagetable_dtor(ptdesc); - tlb_remove_page(tlb, ptdesc_page(ptdesc)); -} -#else -static inline -void paravirt_tlb_remove_table(struct mmu_gather *tlb, void *table) -{ - tlb_remove_table(tlb, table); -} -#endif /* !CONFIG_PT_RECLAIM */ -#endif /* !CONFIG_PARAVIRT */ - gfp_t __userpte_alloc_gfp = GFP_PGTABLE_USER | PGTABLE_HIGHMEM; pgtable_t pte_alloc_one(struct mm_struct *mm) @@ -64,7 +45,7 @@ early_param("userpte", setup_userpte); void ___pte_free_tlb(struct mmu_gather *tlb, struct page *pte) { paravirt_release_pte(page_to_pfn(pte)); - paravirt_tlb_remove_table(tlb, page_ptdesc(pte)); + tlb_remove_table(tlb, page_ptdesc(pte)); } #if CONFIG_PGTABLE_LEVELS > 2 @@ -78,21 +59,21 @@ void ___pmd_free_tlb(struct mmu_gather *tlb, pmd_t *pmd) #ifdef CONFIG_X86_PAE tlb->need_flush_all = 1; #endif - paravirt_tlb_remove_table(tlb, virt_to_ptdesc(pmd)); + tlb_remove_table(tlb, virt_to_ptdesc(pmd)); } #if CONFIG_PGTABLE_LEVELS > 3 void ___pud_free_tlb(struct mmu_gather *tlb, pud_t *pud) { paravirt_release_pud(__pa(pud) >> PAGE_SHIFT); - paravirt_tlb_remove_table(tlb, virt_to_ptdesc(pud)); + tlb_remove_table(tlb, virt_to_ptdesc(pud)); } #if CONFIG_PGTABLE_LEVELS > 4 void ___p4d_free_tlb(struct mmu_gather *tlb, p4d_t *p4d) { paravirt_release_p4d(__pa(p4d) >> PAGE_SHIFT); - paravirt_tlb_remove_table(tlb, virt_to_ptdesc(p4d)); + tlb_remove_table(tlb, virt_to_ptdesc(p4d)); } #endif /* CONFIG_PGTABLE_LEVELS > 4 */ #endif /* CONFIG_PGTABLE_LEVELS > 3 */ diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c index 6cf881a942bb..296346b638bc 100644 --- a/arch/x86/mm/tlb.c +++ b/arch/x86/mm/tlb.c @@ -74,13 +74,15 @@ * use different names for each of them: * * ASID - [0, TLB_NR_DYN_ASIDS-1] - * the canonical identifier for an mm + * the canonical identifier for an mm, dynamically allocated on each CPU + * [TLB_NR_DYN_ASIDS, MAX_ASID_AVAILABLE-1] + * the canonical, global identifier for an mm, identical across all CPUs * - * kPCID - [1, TLB_NR_DYN_ASIDS] + * kPCID - [1, MAX_ASID_AVAILABLE] * the value we write into the PCID part of CR3; corresponds to the * ASID+1, because PCID 0 is special. * - * uPCID - [2048 + 1, 2048 + TLB_NR_DYN_ASIDS] + * uPCID - [2048 + 1, 2048 + MAX_ASID_AVAILABLE] * for KPTI each mm has two address spaces and thus needs two * PCID values, but we can still do with a single ASID denomination * for each mm. Corresponds to kPCID + 2048. @@ -225,6 +227,20 @@ static void choose_new_asid(struct mm_struct *next, u64 next_tlb_gen, return; } + /* + * TLB consistency for global ASIDs is maintained with hardware assisted + * remote TLB flushing. Global ASIDs are always up to date. + */ + if (cpu_feature_enabled(X86_FEATURE_INVLPGB)) { + u16 global_asid = mm_global_asid(next); + + if (global_asid) { + *new_asid = global_asid; + *need_flush = false; + return; + } + } + if (this_cpu_read(cpu_tlbstate.invalidate_other)) clear_asid_other(); @@ -251,6 +267,268 @@ static void choose_new_asid(struct mm_struct *next, u64 next_tlb_gen, *need_flush = true; } +/* + * Global ASIDs are allocated for multi-threaded processes that are + * active on multiple CPUs simultaneously, giving each of those + * processes the same PCID on every CPU, for use with hardware-assisted + * TLB shootdown on remote CPUs, like AMD INVLPGB or Intel RAR. + * + * These global ASIDs are held for the lifetime of the process. + */ +static DEFINE_RAW_SPINLOCK(global_asid_lock); +static u16 last_global_asid = MAX_ASID_AVAILABLE; +static DECLARE_BITMAP(global_asid_used, MAX_ASID_AVAILABLE); +static DECLARE_BITMAP(global_asid_freed, MAX_ASID_AVAILABLE); +static int global_asid_available = MAX_ASID_AVAILABLE - TLB_NR_DYN_ASIDS - 1; + +/* + * When the search for a free ASID in the global ASID space reaches + * MAX_ASID_AVAILABLE, a global TLB flush guarantees that previously + * freed global ASIDs are safe to re-use. + * + * This way the global flush only needs to happen at ASID rollover + * time, and not at ASID allocation time. + */ +static void reset_global_asid_space(void) +{ + lockdep_assert_held(&global_asid_lock); + + invlpgb_flush_all_nonglobals(); + + /* + * The TLB flush above makes it safe to re-use the previously + * freed global ASIDs. + */ + bitmap_andnot(global_asid_used, global_asid_used, + global_asid_freed, MAX_ASID_AVAILABLE); + bitmap_clear(global_asid_freed, 0, MAX_ASID_AVAILABLE); + + /* Restart the search from the start of global ASID space. */ + last_global_asid = TLB_NR_DYN_ASIDS; +} + +static u16 allocate_global_asid(void) +{ + u16 asid; + + lockdep_assert_held(&global_asid_lock); + + /* The previous allocation hit the edge of available address space */ + if (last_global_asid >= MAX_ASID_AVAILABLE - 1) + reset_global_asid_space(); + + asid = find_next_zero_bit(global_asid_used, MAX_ASID_AVAILABLE, last_global_asid); + + if (asid >= MAX_ASID_AVAILABLE && !global_asid_available) { + /* This should never happen. */ + VM_WARN_ONCE(1, "Unable to allocate global ASID despite %d available\n", + global_asid_available); + return 0; + } + + /* Claim this global ASID. */ + __set_bit(asid, global_asid_used); + last_global_asid = asid; + global_asid_available--; + return asid; +} + +/* + * Check whether a process is currently active on more than @threshold CPUs. + * This is a cheap estimation on whether or not it may make sense to assign + * a global ASID to this process, and use broadcast TLB invalidation. + */ +static bool mm_active_cpus_exceeds(struct mm_struct *mm, int threshold) +{ + int count = 0; + int cpu; + + /* This quick check should eliminate most single threaded programs. */ + if (cpumask_weight(mm_cpumask(mm)) <= threshold) + return false; + + /* Slower check to make sure. */ + for_each_cpu(cpu, mm_cpumask(mm)) { + /* Skip the CPUs that aren't really running this process. */ + if (per_cpu(cpu_tlbstate.loaded_mm, cpu) != mm) + continue; + + if (per_cpu(cpu_tlbstate_shared.is_lazy, cpu)) + continue; + + if (++count > threshold) + return true; + } + return false; +} + +/* + * Assign a global ASID to the current process, protecting against + * races between multiple threads in the process. + */ +static void use_global_asid(struct mm_struct *mm) +{ + u16 asid; + + guard(raw_spinlock_irqsave)(&global_asid_lock); + + /* This process is already using broadcast TLB invalidation. */ + if (mm_global_asid(mm)) + return; + + /* + * The last global ASID was consumed while waiting for the lock. + * + * If this fires, a more aggressive ASID reuse scheme might be + * needed. + */ + if (!global_asid_available) { + VM_WARN_ONCE(1, "Ran out of global ASIDs\n"); + return; + } + + asid = allocate_global_asid(); + if (!asid) + return; + + mm_assign_global_asid(mm, asid); +} + +void mm_free_global_asid(struct mm_struct *mm) +{ + if (!cpu_feature_enabled(X86_FEATURE_INVLPGB)) + return; + + if (!mm_global_asid(mm)) + return; + + guard(raw_spinlock_irqsave)(&global_asid_lock); + + /* The global ASID can be re-used only after flush at wrap-around. */ +#ifdef CONFIG_BROADCAST_TLB_FLUSH + __set_bit(mm->context.global_asid, global_asid_freed); + + mm->context.global_asid = 0; + global_asid_available++; +#endif +} + +/* + * Is the mm transitioning from a CPU-local ASID to a global ASID? + */ +static bool mm_needs_global_asid(struct mm_struct *mm, u16 asid) +{ + u16 global_asid = mm_global_asid(mm); + + if (!cpu_feature_enabled(X86_FEATURE_INVLPGB)) + return false; + + /* Process is transitioning to a global ASID */ + if (global_asid && asid != global_asid) + return true; + + return false; +} + +/* + * x86 has 4k ASIDs (2k when compiled with KPTI), but the largest x86 + * systems have over 8k CPUs. Because of this potential ASID shortage, + * global ASIDs are handed out to processes that have frequent TLB + * flushes and are active on 4 or more CPUs simultaneously. + */ +static void consider_global_asid(struct mm_struct *mm) +{ + if (!cpu_feature_enabled(X86_FEATURE_INVLPGB)) + return; + + /* Check every once in a while. */ + if ((current->pid & 0x1f) != (jiffies & 0x1f)) + return; + + /* + * Assign a global ASID if the process is active on + * 4 or more CPUs simultaneously. + */ + if (mm_active_cpus_exceeds(mm, 3)) + use_global_asid(mm); +} + +static void finish_asid_transition(struct flush_tlb_info *info) +{ + struct mm_struct *mm = info->mm; + int bc_asid = mm_global_asid(mm); + int cpu; + + if (!mm_in_asid_transition(mm)) + return; + + for_each_cpu(cpu, mm_cpumask(mm)) { + /* + * The remote CPU is context switching. Wait for that to + * finish, to catch the unlikely case of it switching to + * the target mm with an out of date ASID. + */ + while (READ_ONCE(per_cpu(cpu_tlbstate.loaded_mm, cpu)) == LOADED_MM_SWITCHING) + cpu_relax(); + + if (READ_ONCE(per_cpu(cpu_tlbstate.loaded_mm, cpu)) != mm) + continue; + + /* + * If at least one CPU is not using the global ASID yet, + * send a TLB flush IPI. The IPI should cause stragglers + * to transition soon. + * + * This can race with the CPU switching to another task; + * that results in a (harmless) extra IPI. + */ + if (READ_ONCE(per_cpu(cpu_tlbstate.loaded_mm_asid, cpu)) != bc_asid) { + flush_tlb_multi(mm_cpumask(info->mm), info); + return; + } + } + + /* All the CPUs running this process are using the global ASID. */ + mm_clear_asid_transition(mm); +} + +static void broadcast_tlb_flush(struct flush_tlb_info *info) +{ + bool pmd = info->stride_shift == PMD_SHIFT; + unsigned long asid = mm_global_asid(info->mm); + unsigned long addr = info->start; + + /* + * TLB flushes with INVLPGB are kicked off asynchronously. + * The inc_mm_tlb_gen() guarantees page table updates are done + * before these TLB flushes happen. + */ + if (info->end == TLB_FLUSH_ALL) { + invlpgb_flush_single_pcid_nosync(kern_pcid(asid)); + /* Do any CPUs supporting INVLPGB need PTI? */ + if (cpu_feature_enabled(X86_FEATURE_PTI)) + invlpgb_flush_single_pcid_nosync(user_pcid(asid)); + } else do { + unsigned long nr = 1; + + if (info->stride_shift <= PMD_SHIFT) { + nr = (info->end - addr) >> info->stride_shift; + nr = clamp_val(nr, 1, invlpgb_count_max); + } + + invlpgb_flush_user_nr_nosync(kern_pcid(asid), addr, nr, pmd); + if (cpu_feature_enabled(X86_FEATURE_PTI)) + invlpgb_flush_user_nr_nosync(user_pcid(asid), addr, nr, pmd); + + addr += nr << info->stride_shift; + } while (addr < info->end); + + finish_asid_transition(info); + + /* Wait for the INVLPGBs kicked off above to finish. */ + __tlbsync(); +} + /* * Given an ASID, flush the corresponding user ASID. We can delay this * until the next time we switch to it. @@ -556,7 +834,8 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, */ if (prev == next) { /* Not actually switching mm's */ - VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) != + VM_WARN_ON(is_dyn_asid(prev_asid) && + this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) != next->context.ctx_id); /* @@ -573,6 +852,20 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, !cpumask_test_cpu(cpu, mm_cpumask(next)))) cpumask_set_cpu(cpu, mm_cpumask(next)); + /* Check if the current mm is transitioning to a global ASID */ + if (mm_needs_global_asid(next, prev_asid)) { + next_tlb_gen = atomic64_read(&next->context.tlb_gen); + choose_new_asid(next, next_tlb_gen, &new_asid, &need_flush); + goto reload_tlb; + } + + /* + * Broadcast TLB invalidation keeps this ASID up to date + * all the time. + */ + if (is_global_asid(prev_asid)) + return; + /* * If the CPU is not in lazy TLB mode, we are just switching * from one thread in a process to another thread in the same @@ -606,6 +899,13 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, */ cond_mitigation(tsk); + /* + * Let nmi_uaccess_okay() and finish_asid_transition() + * know that CR3 is changing. + */ + this_cpu_write(cpu_tlbstate.loaded_mm, LOADED_MM_SWITCHING); + barrier(); + /* * Leave this CPU in prev's mm_cpumask. Atomic writes to * mm_cpumask can be expensive under contention. The CPU @@ -620,14 +920,12 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, next_tlb_gen = atomic64_read(&next->context.tlb_gen); choose_new_asid(next, next_tlb_gen, &new_asid, &need_flush); - - /* Let nmi_uaccess_okay() know that we're changing CR3. */ - this_cpu_write(cpu_tlbstate.loaded_mm, LOADED_MM_SWITCHING); - barrier(); } +reload_tlb: new_lam = mm_lam_cr3_mask(next); if (need_flush) { + VM_WARN_ON_ONCE(is_global_asid(new_asid)); this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id); this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen); load_new_mm_cr3(next->pgd, new_asid, new_lam, true); @@ -746,7 +1044,7 @@ static void flush_tlb_func(void *info) const struct flush_tlb_info *f = info; struct mm_struct *loaded_mm = this_cpu_read(cpu_tlbstate.loaded_mm); u32 loaded_mm_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid); - u64 local_tlb_gen = this_cpu_read(cpu_tlbstate.ctxs[loaded_mm_asid].tlb_gen); + u64 local_tlb_gen; bool local = smp_processor_id() == f->initiating_cpu; unsigned long nr_invalidate = 0; u64 mm_tlb_gen; @@ -769,6 +1067,16 @@ static void flush_tlb_func(void *info) if (unlikely(loaded_mm == &init_mm)) return; + /* Reload the ASID if transitioning into or out of a global ASID */ + if (mm_needs_global_asid(loaded_mm, loaded_mm_asid)) { + switch_mm_irqs_off(NULL, loaded_mm, NULL); + loaded_mm_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid); + } + + /* Broadcast ASIDs are always kept up to date with INVLPGB. */ + if (is_global_asid(loaded_mm_asid)) + return; + VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[loaded_mm_asid].ctx_id) != loaded_mm->context.ctx_id); @@ -786,6 +1094,8 @@ static void flush_tlb_func(void *info) return; } + local_tlb_gen = this_cpu_read(cpu_tlbstate.ctxs[loaded_mm_asid].tlb_gen); + if (unlikely(f->new_tlb_gen != TLB_GENERATION_INVALID && f->new_tlb_gen <= local_tlb_gen)) { /* @@ -953,7 +1263,7 @@ STATIC_NOPV void native_flush_tlb_multi(const struct cpumask *cpumask, * up on the new contents of what used to be page tables, while * doing a speculative memory access. */ - if (info->freed_tables) + if (info->freed_tables || mm_in_asid_transition(info->mm)) on_each_cpu_mask(cpumask, flush_tlb_func, (void *)info, true); else on_each_cpu_cond_mask(should_flush_tlb, flush_tlb_func, @@ -1000,6 +1310,15 @@ static struct flush_tlb_info *get_flush_tlb_info(struct mm_struct *mm, BUG_ON(this_cpu_inc_return(flush_tlb_info_idx) != 1); #endif + /* + * If the number of flushes is so large that a full flush + * would be faster, do a full flush. + */ + if ((end - start) >> stride_shift > tlb_single_page_flush_ceiling) { + start = 0; + end = TLB_FLUSH_ALL; + } + info->start = start; info->end = end; info->mm = mm; @@ -1026,17 +1345,8 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start, bool freed_tables) { struct flush_tlb_info *info; + int cpu = get_cpu(); u64 new_tlb_gen; - int cpu; - - cpu = get_cpu(); - - /* Should we flush just the requested range? */ - if ((end == TLB_FLUSH_ALL) || - ((end - start) >> stride_shift) > tlb_single_page_flush_ceiling) { - start = 0; - end = TLB_FLUSH_ALL; - } /* This is also a barrier that synchronizes with switch_mm(). */ new_tlb_gen = inc_mm_tlb_gen(mm); @@ -1049,9 +1359,12 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start, * a local TLB flush is needed. Optimize this use-case by calling * flush_tlb_func_local() directly in this case. */ - if (cpumask_any_but(mm_cpumask(mm), cpu) < nr_cpu_ids) { + if (mm_global_asid(mm)) { + broadcast_tlb_flush(info); + } else if (cpumask_any_but(mm_cpumask(mm), cpu) < nr_cpu_ids) { info->trim_cpumask = should_trim_cpumask(mm); flush_tlb_multi(mm_cpumask(mm), info); + consider_global_asid(mm); } else if (mm == this_cpu_read(cpu_tlbstate.loaded_mm)) { lockdep_assert_irqs_enabled(); local_irq_disable(); @@ -1064,7 +1377,6 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start, mmu_notifier_arch_invalidate_secondary_tlbs(mm, start, end); } - static void do_flush_tlb_all(void *info) { count_vm_tlb_event(NR_TLB_REMOTE_FLUSH_RECEIVED); @@ -1074,7 +1386,32 @@ static void do_flush_tlb_all(void *info) void flush_tlb_all(void) { count_vm_tlb_event(NR_TLB_REMOTE_FLUSH); - on_each_cpu(do_flush_tlb_all, NULL, 1); + + /* First try (faster) hardware-assisted TLB invalidation. */ + if (cpu_feature_enabled(X86_FEATURE_INVLPGB)) + invlpgb_flush_all(); + else + /* Fall back to the IPI-based invalidation. */ + on_each_cpu(do_flush_tlb_all, NULL, 1); +} + +/* Flush an arbitrarily large range of memory with INVLPGB. */ +static void invlpgb_kernel_range_flush(struct flush_tlb_info *info) +{ + unsigned long addr, nr; + + for (addr = info->start; addr < info->end; addr += nr << PAGE_SHIFT) { + nr = (info->end - addr) >> PAGE_SHIFT; + + /* + * INVLPGB has a limit on the size of ranges it can + * flush. Break up large flushes. + */ + nr = clamp_val(nr, 1, invlpgb_count_max); + + invlpgb_flush_addr_nosync(addr, nr); + } + __tlbsync(); } static void do_kernel_range_flush(void *info) @@ -1087,24 +1424,37 @@ static void do_kernel_range_flush(void *info) flush_tlb_one_kernel(addr); } -void flush_tlb_kernel_range(unsigned long start, unsigned long end) +static void kernel_tlb_flush_all(struct flush_tlb_info *info) { - /* Balance as user space task's flush, a bit conservative */ - if (end == TLB_FLUSH_ALL || - (end - start) > tlb_single_page_flush_ceiling << PAGE_SHIFT) { + if (cpu_feature_enabled(X86_FEATURE_INVLPGB)) + invlpgb_flush_all(); + else on_each_cpu(do_flush_tlb_all, NULL, 1); - } else { - struct flush_tlb_info *info; - - preempt_disable(); - info = get_flush_tlb_info(NULL, start, end, 0, false, - TLB_GENERATION_INVALID); +} +static void kernel_tlb_flush_range(struct flush_tlb_info *info) +{ + if (cpu_feature_enabled(X86_FEATURE_INVLPGB)) + invlpgb_kernel_range_flush(info); + else on_each_cpu(do_kernel_range_flush, info, 1); +} - put_flush_tlb_info(); - preempt_enable(); - } +void flush_tlb_kernel_range(unsigned long start, unsigned long end) +{ + struct flush_tlb_info *info; + + guard(preempt)(); + + info = get_flush_tlb_info(NULL, start, end, PAGE_SHIFT, false, + TLB_GENERATION_INVALID); + + if (info->end == TLB_FLUSH_ALL) + kernel_tlb_flush_all(info); + else + kernel_tlb_flush_range(info); + + put_flush_tlb_info(); } /* @@ -1283,7 +1633,10 @@ void arch_tlbbatch_flush(struct arch_tlbflush_unmap_batch *batch) * a local TLB flush is needed. Optimize this use-case by calling * flush_tlb_func_local() directly in this case. */ - if (cpumask_any_but(&batch->cpumask, cpu) < nr_cpu_ids) { + if (cpu_feature_enabled(X86_FEATURE_INVLPGB) && batch->unmapped_pages) { + invlpgb_flush_all_nonglobals(); + batch->unmapped_pages = false; + } else if (cpumask_any_but(&batch->cpumask, cpu) < nr_cpu_ids) { flush_tlb_multi(&batch->cpumask, info); } else if (cpumask_test_cpu(cpu, &batch->cpumask)) { lockdep_assert_irqs_enabled(); diff --git a/arch/x86/xen/mmu_pv.c b/arch/x86/xen/mmu_pv.c index d078de2c952b..38971c6dcd4b 100644 --- a/arch/x86/xen/mmu_pv.c +++ b/arch/x86/xen/mmu_pv.c @@ -2189,7 +2189,6 @@ static const typeof(pv_ops) xen_mmu_ops __initconst = { .flush_tlb_kernel = xen_flush_tlb, .flush_tlb_one_user = xen_flush_tlb_one_user, .flush_tlb_multi = xen_flush_tlb_multi, - .tlb_remove_table = tlb_remove_table, .pgd_alloc = xen_pgd_alloc, .pgd_free = xen_pgd_free, diff --git a/tools/arch/x86/include/asm/msr-index.h b/tools/arch/x86/include/asm/msr-index.h index 3ae84c3b8e6d..dc1c1057f26e 100644 --- a/tools/arch/x86/include/asm/msr-index.h +++ b/tools/arch/x86/include/asm/msr-index.h @@ -25,6 +25,7 @@ #define _EFER_SVME 12 /* Enable virtualization */ #define _EFER_LMSLE 13 /* Long Mode Segment Limit Enable */ #define _EFER_FFXSR 14 /* Enable Fast FXSAVE/FXRSTOR */ +#define _EFER_TCE 15 /* Enable Translation Cache Extensions */ #define _EFER_AUTOIBRS 21 /* Enable Automatic IBRS */ #define EFER_SCE (1<<_EFER_SCE) @@ -34,6 +35,7 @@ #define EFER_SVME (1<<_EFER_SVME) #define EFER_LMSLE (1<<_EFER_LMSLE) #define EFER_FFXSR (1<<_EFER_FFXSR) +#define EFER_TCE (1<<_EFER_TCE) #define EFER_AUTOIBRS (1<<_EFER_AUTOIBRS) /* -- 2.49.0.391.g4bbb303af6 From 5a900dc4f34c93794f70f3266ca5a4f1947e6790 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:31:55 +0200 Subject: [PATCH 3/9] asus Signed-off-by: Peter Jung --- .../ABI/testing/sysfs-platform-asus-wmi | 17 + .../display/dc/dml2/dml2_translation_helper.c | 2 +- drivers/hid/Kconfig | 9 + drivers/hid/Makefile | 1 + drivers/hid/hid-asus-ally.c | 2197 +++++++++++++++++ drivers/hid/hid-asus-ally.h | 398 +++ drivers/hid/hid-asus.c | 130 +- drivers/hid/hid-asus.h | 13 + drivers/hid/hid-ids.h | 1 + drivers/platform/x86/Kconfig | 23 + drivers/platform/x86/Makefile | 1 + drivers/platform/x86/asus-armoury.c | 1202 +++++++++ drivers/platform/x86/asus-armoury.h | 1278 ++++++++++ drivers/platform/x86/asus-wmi.c | 307 ++- include/linux/platform_data/x86/asus-wmi.h | 43 + 15 files changed, 5545 insertions(+), 77 deletions(-) create mode 100644 drivers/hid/hid-asus-ally.c create mode 100644 drivers/hid/hid-asus-ally.h create mode 100644 drivers/hid/hid-asus.h create mode 100644 drivers/platform/x86/asus-armoury.c create mode 100644 drivers/platform/x86/asus-armoury.h diff --git a/Documentation/ABI/testing/sysfs-platform-asus-wmi b/Documentation/ABI/testing/sysfs-platform-asus-wmi index 28144371a0f1..765d50b0d9df 100644 --- a/Documentation/ABI/testing/sysfs-platform-asus-wmi +++ b/Documentation/ABI/testing/sysfs-platform-asus-wmi @@ -63,6 +63,7 @@ Date: Aug 2022 KernelVersion: 6.1 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Switch the GPU hardware MUX mode. Laptops with this feature can can be toggled to boot with only the dGPU (discrete mode) or in standard Optimus/Hybrid mode. On switch a reboot is required: @@ -75,6 +76,7 @@ Date: Aug 2022 KernelVersion: 5.17 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Disable discrete GPU: * 0 - Enable dGPU, * 1 - Disable dGPU @@ -84,6 +86,7 @@ Date: Aug 2022 KernelVersion: 5.17 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Enable the external GPU paired with ROG X-Flow laptops. Toggling this setting will also trigger ACPI to disable the dGPU: @@ -95,6 +98,7 @@ Date: Aug 2022 KernelVersion: 5.17 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Enable an LCD response-time boost to reduce or remove ghosting: * 0 - Disable, * 1 - Enable @@ -104,6 +108,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Get the current charging mode being used: * 1 - Barrel connected charger, * 2 - USB-C charging @@ -114,6 +119,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Show if the egpu (XG Mobile) is correctly connected: * 0 - False, * 1 - True @@ -123,6 +129,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Change the mini-LED mode: * 0 - Single-zone, * 1 - Multi-zone @@ -133,6 +140,7 @@ Date: Apr 2024 KernelVersion: 6.10 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON List the available mini-led modes. What: /sys/devices/platform//ppt_pl1_spl @@ -140,6 +148,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the Package Power Target total of CPU: PL1 on Intel, SPL on AMD. Shown on Intel+Nvidia or AMD+Nvidia based systems: @@ -150,6 +159,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the Slow Package Power Tracking Limit of CPU: PL2 on Intel, SPPT, on AMD. Shown on Intel+Nvidia or AMD+Nvidia based systems: @@ -160,6 +170,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the Fast Package Power Tracking Limit of CPU. AMD+Nvidia only: * min=5, max=250 @@ -168,6 +179,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the APU SPPT limit. Shown on full AMD systems only: * min=5, max=130 @@ -176,6 +188,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the platform SPPT limit. Shown on full AMD systems only: * min=5, max=130 @@ -184,6 +197,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the dynamic boost limit of the Nvidia dGPU: * min=5, max=25 @@ -192,6 +206,7 @@ Date: Jun 2023 KernelVersion: 6.5 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set the target temperature limit of the Nvidia dGPU: * min=75, max=87 @@ -200,6 +215,7 @@ Date: Apr 2024 KernelVersion: 6.10 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set if the BIOS POST sound is played on boot. * 0 - False, * 1 - True @@ -209,6 +225,7 @@ Date: Apr 2024 KernelVersion: 6.10 Contact: "Luke Jones" Description: + DEPRECATED, WILL BE REMOVED SOON Set if the MCU can go in to low-power mode on system sleep * 0 - False, * 1 - True diff --git a/drivers/gpu/drm/amd/display/dc/dml2/dml2_translation_helper.c b/drivers/gpu/drm/amd/display/dc/dml2/dml2_translation_helper.c index b8a34abaf519..ac4692d1b54f 100644 --- a/drivers/gpu/drm/amd/display/dc/dml2/dml2_translation_helper.c +++ b/drivers/gpu/drm/amd/display/dc/dml2/dml2_translation_helper.c @@ -892,7 +892,7 @@ static void populate_dummy_dml_surface_cfg(struct dml_surface_cfg_st *out, unsig out->SurfaceWidthC[location] = in->timing.h_addressable; out->SurfaceHeightC[location] = in->timing.v_addressable; out->PitchY[location] = ((out->SurfaceWidthY[location] + 127) / 128) * 128; - out->PitchC[location] = 0; + out->PitchC[location] = 1; out->DCCEnable[location] = false; out->DCCMetaPitchY[location] = 0; out->DCCMetaPitchC[location] = 0; diff --git a/drivers/hid/Kconfig b/drivers/hid/Kconfig index 4cfea399ebab..d979b18f7f5b 100644 --- a/drivers/hid/Kconfig +++ b/drivers/hid/Kconfig @@ -164,6 +164,15 @@ config HID_ASUS - GL553V series - GL753V series +config HID_ASUS_ALLY + tristate "Asus Ally gamepad configuration support" + depends on USB_HID + depends on LEDS_CLASS + depends on LEDS_CLASS_MULTICOLOR + select POWER_SUPPLY + help + Support for configuring the Asus ROG Ally gamepad using attributes. + config HID_AUREAL tristate "Aureal" help diff --git a/drivers/hid/Makefile b/drivers/hid/Makefile index c7ecfbb3e228..733ab7cc5813 100644 --- a/drivers/hid/Makefile +++ b/drivers/hid/Makefile @@ -31,6 +31,7 @@ obj-$(CONFIG_HID_APPLE) += hid-apple.o obj-$(CONFIG_HID_APPLEIR) += hid-appleir.o obj-$(CONFIG_HID_CREATIVE_SB0540) += hid-creative-sb0540.o obj-$(CONFIG_HID_ASUS) += hid-asus.o +obj-$(CONFIG_HID_ASUS_ALLY) += hid-asus-ally.o obj-$(CONFIG_HID_AUREAL) += hid-aureal.o obj-$(CONFIG_HID_BELKIN) += hid-belkin.o obj-$(CONFIG_HID_BETOP_FF) += hid-betopff.o diff --git a/drivers/hid/hid-asus-ally.c b/drivers/hid/hid-asus-ally.c new file mode 100644 index 000000000000..e78625f70c44 --- /dev/null +++ b/drivers/hid/hid-asus-ally.c @@ -0,0 +1,2197 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * HID driver for Asus ROG laptops and Ally + * + * Copyright (c) 2023 Luke Jones + */ + +#include "linux/compiler_attributes.h" +#include "linux/device.h" +#include +#include +#include "linux/pm.h" +#include "linux/printk.h" +#include "linux/slab.h" +#include +#include +#include +#include +#include + +#include "hid-ids.h" +#include "hid-asus.h" +#include "hid-asus-ally.h" + +#define DEBUG + +#define READY_MAX_TRIES 3 +#define FEATURE_REPORT_ID 0x0d +#define FEATURE_ROG_ALLY_REPORT_ID 0x5a +#define FEATURE_ROG_ALLY_CODE_PAGE 0xD1 +#define FEATURE_ROG_ALLY_REPORT_SIZE 64 +#define ALLY_X_INPUT_REPORT_USB 0x0B +#define ALLY_X_INPUT_REPORT_USB_SIZE 16 + +#define ROG_ALLY_REPORT_SIZE 64 +#define ROG_ALLY_X_MIN_MCU 313 +#define ROG_ALLY_MIN_MCU 319 + +#define FEATURE_KBD_LED_REPORT_ID1 0x5d +#define FEATURE_KBD_LED_REPORT_ID2 0x5e + +#define BTN_DATA_LEN 11; +#define BTN_CODE_BYTES_LEN 8 + +static const u8 EC_INIT_STRING[] = { 0x5A, 'A', 'S', 'U', 'S', ' ', 'T', 'e','c', 'h', '.', 'I', 'n', 'c', '.', '\0' }; +static const u8 EC_MODE_LED_APPLY[] = { 0x5A, 0xB4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; +static const u8 EC_MODE_LED_SET[] = { 0x5A, 0xB5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; +static const u8 FORCE_FEEDBACK_OFF[] = { 0x0D, 0x0F, 0x00, 0x00, 0x00, 0x00, 0xFF, 0x00, 0xEB }; + +static const struct hid_device_id rog_ally_devices[] = { + { HID_USB_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY) }, + { HID_USB_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY_X) }, + {} +}; + +struct btn_code_map { + u64 code; + const char *name; +}; + +static const struct btn_code_map ally_btn_codes[] = { + { 0, "NONE" }, + /* Gamepad button codes */ + { BTN_PAD_A, "PAD_A" }, + { BTN_PAD_B, "PAD_B" }, + { BTN_PAD_X, "PAD_X" }, + { BTN_PAD_Y, "PAD_Y" }, + { BTN_PAD_LB, "PAD_LB" }, + { BTN_PAD_RB, "PAD_RB" }, + { BTN_PAD_LS, "PAD_LS" }, + { BTN_PAD_RS, "PAD_RS" }, + { BTN_PAD_DPAD_UP, "PAD_DPAD_UP" }, + { BTN_PAD_DPAD_DOWN, "PAD_DPAD_DOWN" }, + { BTN_PAD_DPAD_LEFT, "PAD_DPAD_LEFT" }, + { BTN_PAD_DPAD_RIGHT, "PAD_DPAD_RIGHT" }, + { BTN_PAD_VIEW, "PAD_VIEW" }, + { BTN_PAD_MENU, "PAD_MENU" }, + { BTN_PAD_XBOX, "PAD_XBOX" }, + + /* Triggers mapped to keyboard codes */ + { BTN_KB_M2, "KB_M2" }, + { BTN_KB_M1, "KB_M1" }, + { BTN_KB_ESC, "KB_ESC" }, + { BTN_KB_F1, "KB_F1" }, + { BTN_KB_F2, "KB_F2" }, + { BTN_KB_F3, "KB_F3" }, + { BTN_KB_F4, "KB_F4" }, + { BTN_KB_F5, "KB_F5" }, + { BTN_KB_F6, "KB_F6" }, + { BTN_KB_F7, "KB_F7" }, + { BTN_KB_F8, "KB_F8" }, + { BTN_KB_F9, "KB_F9" }, + { BTN_KB_F10, "KB_F10" }, + { BTN_KB_F11, "KB_F11" }, + { BTN_KB_F12, "KB_F12" }, + { BTN_KB_F14, "KB_F14" }, + { BTN_KB_F15, "KB_F15" }, + { BTN_KB_BACKTICK, "KB_BACKTICK" }, + { BTN_KB_1, "KB_1" }, + { BTN_KB_2, "KB_2" }, + { BTN_KB_3, "KB_3" }, + { BTN_KB_4, "KB_4" }, + { BTN_KB_5, "KB_5" }, + { BTN_KB_6, "KB_6" }, + { BTN_KB_7, "KB_7" }, + { BTN_KB_8, "KB_8" }, + { BTN_KB_9, "KB_9" }, + { BTN_KB_0, "KB_0" }, + { BTN_KB_HYPHEN, "KB_HYPHEN" }, + { BTN_KB_EQUALS, "KB_EQUALS" }, + { BTN_KB_BACKSPACE, "KB_BACKSPACE" }, + { BTN_KB_TAB, "KB_TAB" }, + { BTN_KB_Q, "KB_Q" }, + { BTN_KB_W, "KB_W" }, + { BTN_KB_E, "KB_E" }, + { BTN_KB_R, "KB_R" }, + { BTN_KB_T, "KB_T" }, + { BTN_KB_Y, "KB_Y" }, + { BTN_KB_U, "KB_U" }, + { BTN_KB_O, "KB_O" }, + { BTN_KB_P, "KB_P" }, + { BTN_KB_LBRACKET, "KB_LBRACKET" }, + { BTN_KB_RBRACKET, "KB_RBRACKET" }, + { BTN_KB_BACKSLASH, "KB_BACKSLASH" }, + { BTN_KB_CAPS, "KB_CAPS" }, + { BTN_KB_A, "KB_A" }, + { BTN_KB_S, "KB_S" }, + { BTN_KB_D, "KB_D" }, + { BTN_KB_F, "KB_F" }, + { BTN_KB_G, "KB_G" }, + { BTN_KB_H, "KB_H" }, + { BTN_KB_J, "KB_J" }, + { BTN_KB_K, "KB_K" }, + { BTN_KB_L, "KB_L" }, + { BTN_KB_SEMI, "KB_SEMI" }, + { BTN_KB_QUOTE, "KB_QUOTE" }, + { BTN_KB_RET, "KB_RET" }, + { BTN_KB_LSHIFT, "KB_LSHIFT" }, + { BTN_KB_Z, "KB_Z" }, + { BTN_KB_X, "KB_X" }, + { BTN_KB_C, "KB_C" }, + { BTN_KB_V, "KB_V" }, + { BTN_KB_B, "KB_B" }, + { BTN_KB_N, "KB_N" }, + { BTN_KB_M, "KB_M" }, + { BTN_KB_COMMA, "KB_COMMA" }, + { BTN_KB_PERIOD, "KB_PERIOD" }, + { BTN_KB_RSHIFT, "KB_RSHIFT" }, + { BTN_KB_LCTL, "KB_LCTL" }, + { BTN_KB_META, "KB_META" }, + { BTN_KB_LALT, "KB_LALT" }, + { BTN_KB_SPACE, "KB_SPACE" }, + { BTN_KB_RALT, "KB_RALT" }, + { BTN_KB_MENU, "KB_MENU" }, + { BTN_KB_RCTL, "KB_RCTL" }, + { BTN_KB_PRNTSCN, "KB_PRNTSCN" }, + { BTN_KB_SCRLCK, "KB_SCRLCK" }, + { BTN_KB_PAUSE, "KB_PAUSE" }, + { BTN_KB_INS, "KB_INS" }, + { BTN_KB_HOME, "KB_HOME" }, + { BTN_KB_PGUP, "KB_PGUP" }, + { BTN_KB_DEL, "KB_DEL" }, + { BTN_KB_END, "KB_END" }, + { BTN_KB_PGDWN, "KB_PGDWN" }, + { BTN_KB_UP_ARROW, "KB_UP_ARROW" }, + { BTN_KB_DOWN_ARROW, "KB_DOWN_ARROW" }, + { BTN_KB_LEFT_ARROW, "KB_LEFT_ARROW" }, + { BTN_KB_RIGHT_ARROW, "KB_RIGHT_ARROW" }, + + /* Numpad mappings */ + { BTN_NUMPAD_LOCK, "NUMPAD_LOCK" }, + { BTN_NUMPAD_FWDSLASH, "NUMPAD_FWDSLASH" }, + { BTN_NUMPAD_ASTERISK, "NUMPAD_ASTERISK" }, + { BTN_NUMPAD_HYPHEN, "NUMPAD_HYPHEN" }, + { BTN_NUMPAD_0, "NUMPAD_0" }, + { BTN_NUMPAD_1, "NUMPAD_1" }, + { BTN_NUMPAD_2, "NUMPAD_2" }, + { BTN_NUMPAD_3, "NUMPAD_3" }, + { BTN_NUMPAD_4, "NUMPAD_4" }, + { BTN_NUMPAD_5, "NUMPAD_5" }, + { BTN_NUMPAD_6, "NUMPAD_6" }, + { BTN_NUMPAD_7, "NUMPAD_7" }, + { BTN_NUMPAD_8, "NUMPAD_8" }, + { BTN_NUMPAD_9, "NUMPAD_9" }, + { BTN_NUMPAD_PLUS, "NUMPAD_PLUS" }, + { BTN_NUMPAD_ENTER, "NUMPAD_ENTER" }, + { BTN_NUMPAD_PERIOD, "NUMPAD_PERIOD" }, + + /* Mouse mappings */ + { BTN_MOUSE_LCLICK, "MOUSE_LCLICK" }, + { BTN_MOUSE_RCLICK, "MOUSE_RCLICK" }, + { BTN_MOUSE_MCLICK, "MOUSE_MCLICK" }, + { BTN_MOUSE_WHEEL_UP, "MOUSE_WHEEL_UP" }, + { BTN_MOUSE_WHEEL_DOWN, "MOUSE_WHEEL_DOWN" }, + + /* Media mappings */ + { BTN_MEDIA_SCREENSHOT, "MEDIA_SCREENSHOT" }, + { BTN_MEDIA_SHOW_KEYBOARD, "MEDIA_SHOW_KEYBOARD" }, + { BTN_MEDIA_SHOW_DESKTOP, "MEDIA_SHOW_DESKTOP" }, + { BTN_MEDIA_START_RECORDING, "MEDIA_START_RECORDING" }, + { BTN_MEDIA_MIC_OFF, "MEDIA_MIC_OFF" }, + { BTN_MEDIA_VOL_DOWN, "MEDIA_VOL_DOWN" }, + { BTN_MEDIA_VOL_UP, "MEDIA_VOL_UP" }, +}; +static const size_t keymap_len = ARRAY_SIZE(ally_btn_codes); + +/* byte_array must be >= 8 in length */ +static void btn_code_to_byte_array(u64 keycode, u8 *byte_array) +{ + /* Convert the u64 to bytes[8] */ + for (int i = 0; i < 8; ++i) { + byte_array[i] = (keycode >> (56 - 8 * i)) & 0xFF; + } +} + +static u64 name_to_btn(const char *name) +{ + int len = strcspn(name, "\n"); + for (size_t i = 0; i < keymap_len; ++i) { + if (strncmp(ally_btn_codes[i].name, name, len) == 0) { + return ally_btn_codes[i].code; + } + } + return -EINVAL; +} + +static const char* btn_to_name(u64 key) +{ + for (size_t i = 0; i < keymap_len; ++i) { + if (ally_btn_codes[i].code == key) { + return ally_btn_codes[i].name; + } + } + return NULL; +} + +struct btn_data { + u64 button; + u64 macro; + bool turbo; +}; + +struct btn_mapping { + struct btn_data btn_a; + struct btn_data btn_b; + struct btn_data btn_x; + struct btn_data btn_y; + struct btn_data btn_lb; + struct btn_data btn_rb; + struct btn_data btn_ls; + struct btn_data btn_rs; + struct btn_data btn_lt; + struct btn_data btn_rt; + struct btn_data dpad_up; + struct btn_data dpad_down; + struct btn_data dpad_left; + struct btn_data dpad_right; + struct btn_data btn_view; + struct btn_data btn_menu; + struct btn_data btn_m1; + struct btn_data btn_m2; +}; + +struct deadzone { + u8 inner; + u8 outer; +}; + +struct response_curve { + uint8_t move_pct_1; + uint8_t response_pct_1; + uint8_t move_pct_2; + uint8_t response_pct_2; + uint8_t move_pct_3; + uint8_t response_pct_3; + uint8_t move_pct_4; + uint8_t response_pct_4; +} __packed; + +struct js_axis_calibrations { + uint16_t left_y_stable; + uint16_t left_y_min; + uint16_t left_y_max; + uint16_t left_x_stable; + uint16_t left_x_min; + uint16_t left_x_max; + uint16_t right_y_stable; + uint16_t right_y_min; + uint16_t right_y_max; + uint16_t right_x_stable; + uint16_t right_x_min; + uint16_t right_x_max; +} __packed; + +struct tr_axis_calibrations { + uint16_t left_stable; + uint16_t left_max; + uint16_t right_stable; + uint16_t right_max; +} __packed; + +/* ROG Ally has many settings related to the gamepad, all using the same n-key endpoint */ +struct ally_gamepad_cfg { + struct hid_device *hdev; + struct input_dev *input; + + enum xpad_mode mode; + /* + * index: [mode] + */ + struct btn_mapping key_mapping[xpad_mode_mouse]; + /* + * index: left, right + * max: 64 + */ + u8 vibration_intensity[2]; + + /* deadzones */ + struct deadzone ls_dz; // left stick + struct deadzone rs_dz; // right stick + struct deadzone lt_dz; // left trigger + struct deadzone rt_dz; // right trigger + /* anti-deadzones */ + u8 ls_adz; // left stick + u8 rs_adz; // right stick + /* joystick response curves */ + struct response_curve ls_rc; + struct response_curve rs_rc; + + struct js_axis_calibrations js_cal; + struct tr_axis_calibrations tr_cal; +}; + +/* The hatswitch outputs integers, we use them to index this X|Y pair */ +static const int hat_values[][2] = { + { 0, 0 }, { 0, -1 }, { 1, -1 }, { 1, 0 }, { 1, 1 }, + { 0, 1 }, { -1, 1 }, { -1, 0 }, { -1, -1 }, +}; + +/* rumble packet structure */ +struct ff_data { + u8 enable; + u8 magnitude_left; + u8 magnitude_right; + u8 magnitude_strong; + u8 magnitude_weak; + u8 pulse_sustain_10ms; + u8 pulse_release_10ms; + u8 loop_count; +} __packed; + +struct ff_report { + u8 report_id; + struct ff_data ff; +} __packed; + +struct ally_x_input_report { + uint16_t x, y; + uint16_t rx, ry; + uint16_t z, rz; + uint8_t buttons[4]; +} __packed; + +struct ally_x_device { + struct input_dev *input; + struct hid_device *hdev; + spinlock_t lock; + + struct ff_report *ff_packet; + struct work_struct output_worker; + bool output_worker_initialized; + /* Prevent multiple queued event due to the enforced delay in worker */ + bool update_qam_btn; + /* Set if the QAM and AC buttons emit Xbox and Xbox+A */ + bool qam_btns_steam_mode; + bool update_ff; +}; + +struct ally_rgb_dev { + struct hid_device *hdev; + struct led_classdev_mc led_rgb_dev; + struct work_struct work; + bool output_worker_initialized; + spinlock_t lock; + + bool removed; + bool update_rgb; + uint8_t red[4]; + uint8_t green[4]; + uint8_t blue[4]; +}; + +struct ally_rgb_data { + uint8_t brightness; + uint8_t red[4]; + uint8_t green[4]; + uint8_t blue[4]; + bool initialized; +}; + +static struct ally_drvdata { + struct hid_device *hdev; + struct ally_x_device *ally_x; + struct ally_gamepad_cfg *gamepad_cfg; + struct ally_rgb_dev *led_rgb_dev; + struct ally_rgb_data led_rgb_data; + uint mcu_version; +} drvdata; + +static void reverse_bytes_in_pairs(u8 *buf, size_t size) { + uint16_t *word_ptr; + size_t i; + + for (i = 0; i < size; i += 2) { + if (i + 1 < size) { + word_ptr = (uint16_t *)&buf[i]; + *word_ptr = cpu_to_be16(*word_ptr); + } + } +} + +/** + * asus_dev_set_report - send set report request to device. + * + * @hdev: hid device + * @buf: in/out data to transfer + * @len: length of buf + * + * Return: count of data transferred, negative if error + * + * Same behavior as hid_hw_raw_request. Note that the input buffer is duplicated. + */ +static int asus_dev_set_report(struct hid_device *hdev, const u8 *buf, size_t len) +{ + unsigned char *dmabuf; + int ret; + + dmabuf = kmemdup(buf, len, GFP_KERNEL); + if (!dmabuf) + return -ENOMEM; + + ret = hid_hw_raw_request(hdev, buf[0], dmabuf, len, HID_FEATURE_REPORT, + HID_REQ_SET_REPORT); + kfree(dmabuf); + + return ret; +} + +/** + * asus_dev_get_report - send get report request to device. + * + * @hdev: hid device + * @out: buffer to write output data in to + * @len: length the output buffer provided + * + * Return: count of data transferred, negative if error + * + * Same behavior as hid_hw_raw_request. + */ +static int asus_dev_get_report(struct hid_device *hdev, u8 *out, size_t len) +{ + return hid_hw_raw_request(hdev, FEATURE_REPORT_ID, out, len, + HID_FEATURE_REPORT, HID_REQ_GET_REPORT); +} + +static u8 get_endpoint_address(struct hid_device *hdev) +{ + struct usb_interface *intf; + struct usb_host_endpoint *ep; + + intf = to_usb_interface(hdev->dev.parent); + + if (intf) { + ep = intf->cur_altsetting->endpoint; + if (ep) { + return ep->desc.bEndpointAddress; + } + } + + return -ENODEV; +} + +/**************************************************************************************************/ +/* ROG Ally gamepad configuration */ +/**************************************************************************************************/ + +/* This should be called before any attempts to set device functions */ +static int ally_gamepad_check_ready(struct hid_device *hdev) +{ + int ret, count; + u8 *hidbuf; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + ret = 0; + for (count = 0; count < READY_MAX_TRIES; count++) { + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_check_ready; + hidbuf[3] = 01; + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + hid_dbg(hdev, "ROG Ally check failed set report: %d\n", ret); + + hidbuf[0] = hidbuf[1] = hidbuf[2] = hidbuf[3] = 0; + ret = asus_dev_get_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + hid_dbg(hdev, "ROG Ally check failed get report: %d\n", ret); + + ret = hidbuf[2] == xpad_cmd_check_ready; + if (ret) + break; + usleep_range( + 1000, + 2000); /* don't spam the entire loop in less than USB response time */ + } + + if (count == READY_MAX_TRIES) + hid_warn(hdev, "ROG Ally never responded with a ready\n"); + + kfree(hidbuf); + return ret; +} + +/* VIBRATION INTENSITY ****************************************************************************/ +static ssize_t gamepad_vibration_intensity_index_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + return sysfs_emit(buf, "left right\n"); +} + +ALLY_DEVICE_ATTR_RO(gamepad_vibration_intensity_index, vibration_intensity_index); + +static ssize_t _gamepad_apply_intensity(struct hid_device *hdev, + struct ally_gamepad_cfg *ally_cfg) +{ + u8 *hidbuf; + int ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_set_vibe_intensity; + hidbuf[3] = xpad_cmd_len_vibe_intensity; + hidbuf[4] = ally_cfg->vibration_intensity[0]; + hidbuf[5] = ally_cfg->vibration_intensity[1]; + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + goto report_fail; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto report_fail; + +report_fail: + kfree(hidbuf); + return ret; +} + +static ssize_t gamepad_vibration_intensity_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + return sysfs_emit( + buf, "%d %d\n", + ally_cfg->vibration_intensity[0], + ally_cfg->vibration_intensity[1]); +} + +static ssize_t gamepad_vibration_intensity_store(struct device *dev, + struct device_attribute *attr, const char *buf, + size_t count) +{ + struct hid_device *hdev = to_hid_device(dev); + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + u32 left, right; + int ret; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + if (sscanf(buf, "%d %d", &left, &right) != 2) + return -EINVAL; + + if (left > 64 || right > 64) + return -EINVAL; + + ally_cfg->vibration_intensity[0] = left; + ally_cfg->vibration_intensity[1] = right; + + ret = _gamepad_apply_intensity(hdev, ally_cfg); + if (ret < 0) + return ret; + + return count; +} + +ALLY_DEVICE_ATTR_RW(gamepad_vibration_intensity, vibration_intensity); + +/* ANALOGUE DEADZONES *****************************************************************************/ +static ssize_t _gamepad_apply_deadzones(struct hid_device *hdev, + struct ally_gamepad_cfg *ally_cfg) +{ + u8 *hidbuf; + int ret; + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + return ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_set_js_dz; + hidbuf[3] = xpad_cmd_len_deadzone; + hidbuf[4] = ally_cfg->ls_dz.inner; + hidbuf[5] = ally_cfg->ls_dz.outer; + hidbuf[6] = ally_cfg->rs_dz.inner; + hidbuf[7] = ally_cfg->rs_dz.outer; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto end; + + hidbuf[2] = xpad_cmd_set_tr_dz; + hidbuf[4] = ally_cfg->lt_dz.inner; + hidbuf[5] = ally_cfg->lt_dz.outer; + hidbuf[6] = ally_cfg->rt_dz.inner; + hidbuf[7] = ally_cfg->rt_dz.outer; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto end; + +end: + kfree(hidbuf); + return ret; +} + +static void _gamepad_set_deadzones_default(struct ally_gamepad_cfg *ally_cfg) +{ + ally_cfg->ls_dz.inner = 0x00; + ally_cfg->ls_dz.outer = 0x64; + ally_cfg->rs_dz.inner = 0x00; + ally_cfg->rs_dz.outer = 0x64; + ally_cfg->lt_dz.inner = 0x00; + ally_cfg->lt_dz.outer = 0x64; + ally_cfg->rt_dz.inner = 0x00; + ally_cfg->rt_dz.outer = 0x64; +} + +static ssize_t axis_xyz_deadzone_index_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + return sysfs_emit(buf, "inner outer\n"); +} + +ALLY_DEVICE_ATTR_RO(axis_xyz_deadzone_index, deadzone_index); + +ALLY_DEADZONES(axis_xy_left, ls_dz); +ALLY_DEADZONES(axis_xy_right, rs_dz); +ALLY_DEADZONES(axis_z_left, lt_dz); +ALLY_DEADZONES(axis_z_right, rt_dz); + +/* ANTI-DEADZONES *********************************************************************************/ +static ssize_t _gamepad_apply_js_ADZ(struct hid_device *hdev, + struct ally_gamepad_cfg *ally_cfg) +{ + u8 *hidbuf; + int ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_set_adz; + hidbuf[3] = xpad_cmd_len_adz; + hidbuf[4] = ally_cfg->ls_adz; + hidbuf[5] = ally_cfg->rs_adz; + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + goto report_fail; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto report_fail; + +report_fail: + kfree(hidbuf); + return ret; +} + +static void _gamepad_set_anti_deadzones_default(struct ally_gamepad_cfg *ally_cfg) +{ + ally_cfg->ls_adz = 0x00; + ally_cfg->rs_adz = 0x00; +} + +static ssize_t _gamepad_js_ADZ_store(struct device *dev, const char *buf, u8 *adz) +{ + int ret, val; + + ret = kstrtoint(buf, 0, &val); + if (ret) + return ret; + + if (val < 0 || val > 32) + return -EINVAL; + + *adz = val; + + return ret; +} + +static ssize_t axis_xy_left_anti_deadzone_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + + return sysfs_emit(buf, "%d\n", ally_cfg->ls_adz); +} + +static ssize_t axis_xy_left_anti_deadzone_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + int ret; + + ret = _gamepad_js_ADZ_store(dev, buf, &ally_cfg->ls_adz); + if (ret) + return ret; + + return count; +} +ALLY_DEVICE_ATTR_RW(axis_xy_left_anti_deadzone, anti_deadzone); + +static ssize_t axis_xy_right_anti_deadzone_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + + return sysfs_emit(buf, "%d\n", ally_cfg->rs_adz); +} + +static ssize_t axis_xy_right_anti_deadzone_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t count) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + int ret; + + ret = _gamepad_js_ADZ_store(dev, buf, &ally_cfg->rs_adz); + if (ret) + return ret; + + return count; +} +ALLY_DEVICE_ATTR_RW(axis_xy_right_anti_deadzone, anti_deadzone); + +/* JS RESPONSE CURVES *****************************************************************************/ +static void _gamepad_set_js_response_curves_default(struct ally_gamepad_cfg *ally_cfg) +{ + struct response_curve *js1_rc = &ally_cfg->ls_rc; + struct response_curve *js2_rc = &ally_cfg->rs_rc; + js1_rc->move_pct_1 = js2_rc->move_pct_1 = 0x16; // 25% + js1_rc->move_pct_2 = js2_rc->move_pct_2 = 0x32; // 50% + js1_rc->move_pct_3 = js2_rc->move_pct_3 = 0x48; // 75% + js1_rc->move_pct_4 = js2_rc->move_pct_4 = 0x64; // 100% + js1_rc->response_pct_1 = js2_rc->response_pct_1 = 0x16; + js1_rc->response_pct_2 = js2_rc->response_pct_2 = 0x32; + js1_rc->response_pct_3 = js2_rc->response_pct_3 = 0x48; + js1_rc->response_pct_4 = js2_rc->response_pct_4 = 0x64; +} + +static ssize_t _gamepad_apply_response_curves(struct hid_device *hdev, + struct ally_gamepad_cfg *ally_cfg) +{ + u8 *hidbuf; + int ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + memcpy(&hidbuf[2], &ally_cfg->ls_rc, sizeof(ally_cfg->ls_rc)); + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + goto report_fail; + + hidbuf[4] = 0x02; + memcpy(&hidbuf[5], &ally_cfg->rs_rc, sizeof(ally_cfg->rs_rc)); + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + goto report_fail; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto report_fail; + +report_fail: + kfree(hidbuf); + return ret; +} + +ALLY_JS_RC_POINT(axis_xy_left, move, 1); +ALLY_JS_RC_POINT(axis_xy_left, move, 2); +ALLY_JS_RC_POINT(axis_xy_left, move, 3); +ALLY_JS_RC_POINT(axis_xy_left, move, 4); +ALLY_JS_RC_POINT(axis_xy_left, response, 1); +ALLY_JS_RC_POINT(axis_xy_left, response, 2); +ALLY_JS_RC_POINT(axis_xy_left, response, 3); +ALLY_JS_RC_POINT(axis_xy_left, response, 4); + +ALLY_JS_RC_POINT(axis_xy_right, move, 1); +ALLY_JS_RC_POINT(axis_xy_right, move, 2); +ALLY_JS_RC_POINT(axis_xy_right, move, 3); +ALLY_JS_RC_POINT(axis_xy_right, move, 4); +ALLY_JS_RC_POINT(axis_xy_right, response, 1); +ALLY_JS_RC_POINT(axis_xy_right, response, 2); +ALLY_JS_RC_POINT(axis_xy_right, response, 3); +ALLY_JS_RC_POINT(axis_xy_right, response, 4); + +/* CALIBRATIONS ***********************************************************************************/ +static int gamepad_get_calibration(struct hid_device *hdev) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + u8 *hidbuf; + int ret, i; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + for (i = 0; i < 2; i++) { + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = 0xD0; + hidbuf[2] = 0x03; + hidbuf[3] = i + 1; // 0x01 JS, 0x02 TR + hidbuf[4] = 0x20; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) { + hid_warn(hdev, "ROG Ally check failed set report: %d\n", ret); + goto cleanup; + } + + memset(hidbuf, 0, FEATURE_ROG_ALLY_REPORT_SIZE); + ret = asus_dev_get_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0 || hidbuf[5] != 1) { + hid_warn(hdev, "ROG Ally check failed get report: %d\n", ret); + goto cleanup; + } + + if (i == 0) { + /* Joystick calibration */ + reverse_bytes_in_pairs(&hidbuf[6], sizeof(struct js_axis_calibrations)); + ally_cfg->js_cal = *(struct js_axis_calibrations *)&hidbuf[6]; + print_hex_dump(KERN_INFO, "HID Buffer JS: ", DUMP_PREFIX_OFFSET, 16, 1, hidbuf, 32, true); + struct js_axis_calibrations *cal = &drvdata.gamepad_cfg->js_cal; + pr_err("LS_CAL: X: %d, Min: %d, Max: %d", cal->left_x_stable, cal->left_x_min, cal->left_x_max); + pr_err("LS_CAL: Y: %d, Min: %d, Max: %d", cal->left_y_stable, cal->left_y_min, cal->left_y_max); + pr_err("RS_CAL: X: %d, Min: %d, Max: %d", cal->right_x_stable, cal->right_x_min, cal->right_x_max); + pr_err("RS_CAL: Y: %d, Min: %d, Max: %d", cal->right_y_stable, cal->right_y_min, cal->right_y_max); + } else { + /* Trigger calibration */ + reverse_bytes_in_pairs(&hidbuf[6], sizeof(struct tr_axis_calibrations)); + ally_cfg->tr_cal = *(struct tr_axis_calibrations *)&hidbuf[6]; + print_hex_dump(KERN_INFO, "HID Buffer TR: ", DUMP_PREFIX_OFFSET, 16, 1, hidbuf, 32, true); + } + } + +cleanup: + kfree(hidbuf); + return ret; +} + +static struct attribute *axis_xy_left_attrs[] = { + &dev_attr_axis_xy_left_anti_deadzone.attr, + &dev_attr_axis_xy_left_deadzone.attr, + &dev_attr_axis_xyz_deadzone_index.attr, + &dev_attr_axis_xy_left_move_1.attr, + &dev_attr_axis_xy_left_move_2.attr, + &dev_attr_axis_xy_left_move_3.attr, + &dev_attr_axis_xy_left_move_4.attr, + &dev_attr_axis_xy_left_response_1.attr, + &dev_attr_axis_xy_left_response_2.attr, + &dev_attr_axis_xy_left_response_3.attr, + &dev_attr_axis_xy_left_response_4.attr, + NULL +}; +static const struct attribute_group axis_xy_left_attr_group = { + .name = "axis_xy_left", + .attrs = axis_xy_left_attrs, +}; + +static struct attribute *axis_xy_right_attrs[] = { + &dev_attr_axis_xy_right_anti_deadzone.attr, + &dev_attr_axis_xy_right_deadzone.attr, + &dev_attr_axis_xyz_deadzone_index.attr, + &dev_attr_axis_xy_right_move_1.attr, + &dev_attr_axis_xy_right_move_2.attr, + &dev_attr_axis_xy_right_move_3.attr, + &dev_attr_axis_xy_right_move_4.attr, + &dev_attr_axis_xy_right_response_1.attr, + &dev_attr_axis_xy_right_response_2.attr, + &dev_attr_axis_xy_right_response_3.attr, + &dev_attr_axis_xy_right_response_4.attr, + NULL +}; +static const struct attribute_group axis_xy_right_attr_group = { + .name = "axis_xy_right", + .attrs = axis_xy_right_attrs, +}; + +static struct attribute *axis_z_left_attrs[] = { + &dev_attr_axis_z_left_deadzone.attr, + &dev_attr_axis_xyz_deadzone_index.attr, + NULL, +}; +static const struct attribute_group axis_z_left_attr_group = { + .name = "axis_z_left", + .attrs = axis_z_left_attrs, +}; + +static struct attribute *axis_z_right_attrs[] = { + &dev_attr_axis_z_right_deadzone.attr, + &dev_attr_axis_xyz_deadzone_index.attr, + NULL, +}; +static const struct attribute_group axis_z_right_attr_group = { + .name = "axis_z_right", + .attrs = axis_z_right_attrs, +}; + +/* A HID packet conatins mappings for two buttons: btn1, btn1_macro, btn2, btn2_macro */ +static void _btn_pair_to_hid_pkt(struct ally_gamepad_cfg *ally_cfg, + enum btn_pair_index pair, + struct btn_data *btn1, struct btn_data *btn2, + u8 *out, int out_len) +{ + int start = 5; + + out[0] = FEATURE_ROG_ALLY_REPORT_ID; + out[1] = FEATURE_ROG_ALLY_CODE_PAGE; + out[2] = xpad_cmd_set_mapping; + out[3] = pair; + out[4] = xpad_cmd_len_mapping; + + btn_code_to_byte_array(btn1->button, &out[start]); + start += BTN_DATA_LEN; + btn_code_to_byte_array(btn1->macro, &out[start]); + start += BTN_DATA_LEN; + btn_code_to_byte_array(btn2->button, &out[start]); + start += BTN_DATA_LEN; + btn_code_to_byte_array(btn2->macro, &out[start]); + //print_hex_dump(KERN_DEBUG, "byte_array: ", DUMP_PREFIX_OFFSET, 64, 1, out, 64, false); +} + +/* Apply the mapping pair to the device */ +static int _gamepad_apply_btn_pair(struct hid_device *hdev, struct ally_gamepad_cfg *ally_cfg, + enum btn_pair_index btn_pair) +{ + u8 mode = ally_cfg->mode - 1; + struct btn_data *btn1, *btn2; + u8 *hidbuf; + int ret; + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + return ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + switch (btn_pair) { + case btn_pair_dpad_u_d: + btn1 = &ally_cfg->key_mapping[mode].dpad_up; + btn2 = &ally_cfg->key_mapping[mode].dpad_down; + break; + case btn_pair_dpad_l_r: + btn1 = &ally_cfg->key_mapping[mode].dpad_left; + btn2 = &ally_cfg->key_mapping[mode].dpad_right; + break; + case btn_pair_ls_rs: + btn1 = &ally_cfg->key_mapping[mode].btn_ls; + btn2 = &ally_cfg->key_mapping[mode].btn_rs; + break; + case btn_pair_lb_rb: + btn1 = &ally_cfg->key_mapping[mode].btn_lb; + btn2 = &ally_cfg->key_mapping[mode].btn_rb; + break; + case btn_pair_lt_rt: + btn1 = &ally_cfg->key_mapping[mode].btn_lt; + btn2 = &ally_cfg->key_mapping[mode].btn_rt; + break; + case btn_pair_a_b: + btn1 = &ally_cfg->key_mapping[mode].btn_a; + btn2 = &ally_cfg->key_mapping[mode].btn_b; + break; + case btn_pair_x_y: + btn1 = &ally_cfg->key_mapping[mode].btn_x; + btn2 = &ally_cfg->key_mapping[mode].btn_y; + break; + case btn_pair_view_menu: + btn1 = &ally_cfg->key_mapping[mode].btn_view; + btn2 = &ally_cfg->key_mapping[mode].btn_menu; + break; + case btn_pair_m1_m2: + btn1 = &ally_cfg->key_mapping[mode].btn_m1; + btn2 = &ally_cfg->key_mapping[mode].btn_m2; + break; + default: + break; + } + + _btn_pair_to_hid_pkt(ally_cfg, btn_pair, btn1, btn2, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + + kfree(hidbuf); + + return ret; +} + +static int _gamepad_apply_turbo(struct hid_device *hdev, struct ally_gamepad_cfg *ally_cfg) +{ + struct btn_mapping *map = &ally_cfg->key_mapping[ally_cfg->mode - 1]; + u8 *hidbuf; + int ret; + + /* set turbo */ + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_set_turbo; + hidbuf[3] = xpad_cmd_len_turbo; + + hidbuf[4] = map->dpad_up.turbo; + hidbuf[6] = map->dpad_down.turbo; + hidbuf[8] = map->dpad_left.turbo; + hidbuf[10] = map->dpad_right.turbo; + + hidbuf[12] = map->btn_ls.turbo; + hidbuf[14] = map->btn_rs.turbo; + hidbuf[16] = map->btn_lb.turbo; + hidbuf[18] = map->btn_rb.turbo; + + hidbuf[20] = map->btn_a.turbo; + hidbuf[22] = map->btn_b.turbo; + hidbuf[24] = map->btn_x.turbo; + hidbuf[26] = map->btn_y.turbo; + + hidbuf[28] = map->btn_lt.turbo; + hidbuf[30] = map->btn_rt.turbo; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + + kfree(hidbuf); + + return ret; +} + +static ssize_t _gamepad_apply_all(struct hid_device *hdev, struct ally_gamepad_cfg *ally_cfg) +{ + int ret; + + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_dpad_u_d); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_dpad_l_r); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_ls_rs); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_lb_rb); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_a_b); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_x_y); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_view_menu); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_m1_m2); + if (ret < 0) + return ret; + ret = _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_lt_rt); + if (ret < 0) + return ret; + ret = _gamepad_apply_turbo(hdev, ally_cfg); + if (ret < 0) + return ret; + ret = _gamepad_apply_deadzones(hdev, ally_cfg); + if (ret < 0) + return ret; + ret = _gamepad_apply_js_ADZ(hdev, ally_cfg); + if (ret < 0) + return ret; + ret =_gamepad_apply_response_curves(hdev, ally_cfg); + if (ret < 0) + return ret; + + return 0; +} + +static ssize_t gamepad_apply_all_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + struct hid_device *hdev = to_hid_device(dev); + int ret; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + ret = _gamepad_apply_all(hdev, ally_cfg); + if (ret < 0) + return ret; + + return count; +} +ALLY_DEVICE_ATTR_WO(gamepad_apply_all, apply_all); + +/* button map attributes, regular and macro*/ +ALLY_BTN_MAPPING(m1, btn_m1); +ALLY_BTN_MAPPING(m2, btn_m2); +ALLY_BTN_MAPPING(view, btn_view); +ALLY_BTN_MAPPING(menu, btn_menu); +ALLY_TURBO_BTN_MAPPING(a, btn_a); +ALLY_TURBO_BTN_MAPPING(b, btn_b); +ALLY_TURBO_BTN_MAPPING(x, btn_x); +ALLY_TURBO_BTN_MAPPING(y, btn_y); +ALLY_TURBO_BTN_MAPPING(lb, btn_lb); +ALLY_TURBO_BTN_MAPPING(rb, btn_rb); +ALLY_TURBO_BTN_MAPPING(ls, btn_ls); +ALLY_TURBO_BTN_MAPPING(rs, btn_rs); +ALLY_TURBO_BTN_MAPPING(lt, btn_lt); +ALLY_TURBO_BTN_MAPPING(rt, btn_rt); +ALLY_TURBO_BTN_MAPPING(dpad_u, dpad_up); +ALLY_TURBO_BTN_MAPPING(dpad_d, dpad_down); +ALLY_TURBO_BTN_MAPPING(dpad_l, dpad_left); +ALLY_TURBO_BTN_MAPPING(dpad_r, dpad_right); + +static void _gamepad_set_xpad_default(struct ally_gamepad_cfg *ally_cfg) +{ + struct btn_mapping *map = &ally_cfg->key_mapping[ally_cfg->mode - 1]; + map->btn_m1.button = BTN_KB_M1; + map->btn_m2.button = BTN_KB_M2; + map->btn_a.button = BTN_PAD_A; + map->btn_b.button = BTN_PAD_B; + map->btn_x.button = BTN_PAD_X; + map->btn_y.button = BTN_PAD_Y; + map->btn_lb.button = BTN_PAD_LB; + map->btn_rb.button = BTN_PAD_RB; + map->btn_lt.button = BTN_PAD_LT; + map->btn_rt.button = BTN_PAD_RT; + map->btn_ls.button = BTN_PAD_LS; + map->btn_rs.button = BTN_PAD_RS; + map->dpad_up.button = BTN_PAD_DPAD_UP; + map->dpad_down.button = BTN_PAD_DPAD_DOWN; + map->dpad_left.button = BTN_PAD_DPAD_LEFT; + map->dpad_right.button = BTN_PAD_DPAD_RIGHT; + map->btn_view.button = BTN_PAD_VIEW; + map->btn_menu.button = BTN_PAD_MENU; +} + +static ssize_t btn_mapping_reset_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + switch (ally_cfg->mode) { + case xpad_mode_game: + _gamepad_set_xpad_default(ally_cfg); + break; + default: + _gamepad_set_xpad_default(ally_cfg); + break; + } + + return count; +} +ALLY_DEVICE_ATTR_WO(btn_mapping_reset, reset_btn_mapping); + +/* GAMEPAD MODE */ +static ssize_t _gamepad_set_mode(struct hid_device *hdev, struct ally_gamepad_cfg *ally_cfg, + int val) +{ + u8 *hidbuf; + int ret; + + hidbuf = kzalloc(FEATURE_ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + if (!hidbuf) + return -ENOMEM; + + hidbuf[0] = FEATURE_ROG_ALLY_REPORT_ID; + hidbuf[1] = FEATURE_ROG_ALLY_CODE_PAGE; + hidbuf[2] = xpad_cmd_set_mode; + hidbuf[3] = xpad_cmd_len_mode; + hidbuf[4] = val; + + ret = ally_gamepad_check_ready(hdev); + if (ret < 0) + goto report_fail; + + ret = asus_dev_set_report(hdev, hidbuf, FEATURE_ROG_ALLY_REPORT_SIZE); + if (ret < 0) + goto report_fail; + + ret = _gamepad_apply_all(hdev, ally_cfg); + if (ret < 0) + goto report_fail; + +report_fail: + kfree(hidbuf); + return ret; +} + +static ssize_t gamepad_mode_show(struct device *dev, struct device_attribute *attr, char *buf) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + return sysfs_emit(buf, "%d\n", ally_cfg->mode); +} + +static ssize_t gamepad_mode_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct hid_device *hdev = to_hid_device(dev); + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + int ret, val; + + if (!drvdata.gamepad_cfg) + return -ENODEV; + + ret = kstrtoint(buf, 0, &val); + if (ret) + return ret; + + if (val < xpad_mode_game || val > xpad_mode_mouse) + return -EINVAL; + + ally_cfg->mode = val; + + ret = _gamepad_set_mode(hdev, ally_cfg, val); + if (ret < 0) + return ret; + + return count; +} + +DEVICE_ATTR_RW(gamepad_mode); + +static ssize_t mcu_version_show(struct device *dev, struct device_attribute *attr, char *buf) +{ + return sysfs_emit(buf, "%d\n", drvdata.mcu_version); +} + +DEVICE_ATTR_RO(mcu_version); + +/* ROOT LEVEL ATTRS *******************************************************************************/ +static struct attribute *gamepad_device_attrs[] = { + &dev_attr_btn_mapping_reset.attr, + &dev_attr_gamepad_mode.attr, + &dev_attr_gamepad_apply_all.attr, + &dev_attr_gamepad_vibration_intensity.attr, + &dev_attr_gamepad_vibration_intensity_index.attr, + &dev_attr_mcu_version.attr, + NULL +}; + +static const struct attribute_group ally_controller_attr_group = { + .attrs = gamepad_device_attrs, +}; + +static const struct attribute_group *gamepad_device_attr_groups[] = { + &ally_controller_attr_group, + &axis_xy_left_attr_group, + &axis_xy_right_attr_group, + &axis_z_left_attr_group, + &axis_z_right_attr_group, + &btn_mapping_m1_attr_group, + &btn_mapping_m2_attr_group, + &btn_mapping_a_attr_group, + &btn_mapping_b_attr_group, + &btn_mapping_x_attr_group, + &btn_mapping_y_attr_group, + &btn_mapping_lb_attr_group, + &btn_mapping_rb_attr_group, + &btn_mapping_ls_attr_group, + &btn_mapping_rs_attr_group, + &btn_mapping_lt_attr_group, + &btn_mapping_rt_attr_group, + &btn_mapping_dpad_u_attr_group, + &btn_mapping_dpad_d_attr_group, + &btn_mapping_dpad_l_attr_group, + &btn_mapping_dpad_r_attr_group, + &btn_mapping_view_attr_group, + &btn_mapping_menu_attr_group, + NULL, +}; + +static struct ally_gamepad_cfg *ally_gamepad_cfg_create(struct hid_device *hdev) +{ + struct ally_gamepad_cfg *ally_cfg; + struct input_dev *input_dev; + int err; + + ally_cfg = devm_kzalloc(&hdev->dev, sizeof(*ally_cfg), GFP_KERNEL); + if (!ally_cfg) + return ERR_PTR(-ENOMEM); + ally_cfg->hdev = hdev; + // Allocate memory for each mode's `btn_mapping` + ally_cfg->mode = xpad_mode_game; + + input_dev = devm_input_allocate_device(&hdev->dev); + if (!input_dev) { + err = -ENOMEM; + goto free_ally_cfg; + } + + input_dev->id.bustype = hdev->bus; + input_dev->id.vendor = hdev->vendor; + input_dev->id.product = hdev->product; + input_dev->id.version = hdev->version; + input_dev->uniq = hdev->uniq; + input_dev->name = "ASUS ROG Ally Config"; + input_set_capability(input_dev, EV_KEY, KEY_PROG1); + input_set_capability(input_dev, EV_KEY, KEY_F16); + input_set_capability(input_dev, EV_KEY, KEY_F17); + input_set_capability(input_dev, EV_KEY, KEY_F18); + input_set_drvdata(input_dev, hdev); + + err = input_register_device(input_dev); + if (err) + goto free_input_dev; + ally_cfg->input = input_dev; + + /* ignore all errors for this as they are related to USB HID I/O */ + _gamepad_set_xpad_default(ally_cfg); + ally_cfg->key_mapping[ally_cfg->mode - 1].btn_m1.button = BTN_KB_M1; + ally_cfg->key_mapping[ally_cfg->mode - 1].btn_m2.button = BTN_KB_M2; + _gamepad_apply_btn_pair(hdev, ally_cfg, btn_pair_m1_m2); + gamepad_get_calibration(hdev); + + ally_cfg->vibration_intensity[0] = 0x64; + ally_cfg->vibration_intensity[1] = 0x64; + _gamepad_set_deadzones_default(ally_cfg); + _gamepad_set_anti_deadzones_default(ally_cfg); + _gamepad_set_js_response_curves_default(ally_cfg); + + drvdata.gamepad_cfg = ally_cfg; // Must asign before attr group setup + if (sysfs_create_groups(&hdev->dev.kobj, gamepad_device_attr_groups)) { + err = -ENODEV; + goto unregister_input_dev; + } + + return ally_cfg; + +unregister_input_dev: + input_unregister_device(input_dev); + ally_cfg->input = NULL; // Prevent double free when kfree(ally_cfg) happens + +free_input_dev: + devm_kfree(&hdev->dev, input_dev); + +free_ally_cfg: + devm_kfree(&hdev->dev, ally_cfg); + return ERR_PTR(err); +} + +static void ally_cfg_remove(struct hid_device *hdev) +{ + // __gamepad_set_mode(hdev, drvdata.gamepad_cfg, xpad_mode_mouse); + sysfs_remove_groups(&hdev->dev.kobj, gamepad_device_attr_groups); +} + +/**************************************************************************************************/ +/* ROG Ally gamepad i/o and force-feedback */ +/**************************************************************************************************/ +static int ally_x_raw_event(struct ally_x_device *ally_x, struct hid_report *report, u8 *data, + int size) +{ + struct ally_x_input_report *in_report; + unsigned long flags; + u8 byte; + + if (data[0] == 0x0B) { + in_report = (struct ally_x_input_report *)&data[1]; + + input_report_abs(ally_x->input, ABS_X, in_report->x); + input_report_abs(ally_x->input, ABS_Y, in_report->y); + input_report_abs(ally_x->input, ABS_RX, in_report->rx); + input_report_abs(ally_x->input, ABS_RY, in_report->ry); + input_report_abs(ally_x->input, ABS_Z, in_report->z); + input_report_abs(ally_x->input, ABS_RZ, in_report->rz); + + byte = in_report->buttons[0]; + input_report_key(ally_x->input, BTN_A, byte & BIT(0)); + input_report_key(ally_x->input, BTN_B, byte & BIT(1)); + input_report_key(ally_x->input, BTN_X, byte & BIT(2)); + input_report_key(ally_x->input, BTN_Y, byte & BIT(3)); + input_report_key(ally_x->input, BTN_TL, byte & BIT(4)); + input_report_key(ally_x->input, BTN_TR, byte & BIT(5)); + input_report_key(ally_x->input, BTN_SELECT, byte & BIT(6)); + input_report_key(ally_x->input, BTN_START, byte & BIT(7)); + + byte = in_report->buttons[1]; + input_report_key(ally_x->input, BTN_THUMBL, byte & BIT(0)); + input_report_key(ally_x->input, BTN_THUMBR, byte & BIT(1)); + input_report_key(ally_x->input, BTN_MODE, byte & BIT(2)); + + byte = in_report->buttons[2]; + input_report_abs(ally_x->input, ABS_HAT0X, hat_values[byte][0]); + input_report_abs(ally_x->input, ABS_HAT0Y, hat_values[byte][1]); + } + /* + * The MCU used on Ally provides many devices: gamepad, keyboord, mouse, other. + * The AC and QAM buttons route through another interface making it difficult to + * use the events unless we grab those and use them here. Only works for Ally X. + */ + else if (data[0] == 0x5A) { + if (ally_x->qam_btns_steam_mode) { + spin_lock_irqsave(&ally_x->lock, flags); + if (data[1] == 0x38 && !ally_x->update_qam_btn) { + ally_x->update_qam_btn = true; + if (ally_x->output_worker_initialized) + schedule_work(&ally_x->output_worker); + } + spin_unlock_irqrestore(&ally_x->lock, flags); + /* Left/XBox button. Long press does ctrl+alt+del which we can't catch */ + input_report_key(ally_x->input, BTN_MODE, data[1] == 0xA6); + } else { + input_report_key(ally_x->input, KEY_F16, data[1] == 0xA6); + input_report_key(ally_x->input, KEY_PROG1, data[1] == 0x38); + } + /* QAM long press */ + input_report_key(ally_x->input, KEY_F17, data[1] == 0xA7); + /* QAM long press released */ + input_report_key(ally_x->input, KEY_F18, data[1] == 0xA8); + } + + input_sync(ally_x->input); + + return 0; +} + +static struct input_dev *ally_x_alloc_input_dev(struct hid_device *hdev, + const char *name_suffix) +{ + struct input_dev *input_dev; + + input_dev = devm_input_allocate_device(&hdev->dev); + if (!input_dev) + return ERR_PTR(-ENOMEM); + + input_dev->id.bustype = hdev->bus; + input_dev->id.vendor = hdev->vendor; + input_dev->id.product = hdev->product; + input_dev->id.version = hdev->version; + input_dev->uniq = hdev->uniq; + input_dev->name = "ASUS ROG Ally X Gamepad"; + + input_set_drvdata(input_dev, hdev); + + return input_dev; +} + +static int ally_x_play_effect(struct input_dev *idev, void *data, struct ff_effect *effect) +{ + struct ally_x_device *ally_x = drvdata.ally_x; + unsigned long flags; + + if (effect->type != FF_RUMBLE) + return 0; + + spin_lock_irqsave(&ally_x->lock, flags); + ally_x->ff_packet->ff.magnitude_strong = effect->u.rumble.strong_magnitude / 512; + ally_x->ff_packet->ff.magnitude_weak = effect->u.rumble.weak_magnitude / 512; + ally_x->update_ff = true; + spin_unlock_irqrestore(&ally_x->lock, flags); + + if (ally_x->output_worker_initialized) + schedule_work(&ally_x->output_worker); + + return 0; +} + +static void ally_x_work(struct work_struct *work) +{ + struct ally_x_device *ally_x = container_of(work, struct ally_x_device, output_worker); + struct ff_report *ff_report = NULL; + bool update_qam = false; + bool update_ff = false; + unsigned long flags; + + spin_lock_irqsave(&ally_x->lock, flags); + update_ff = ally_x->update_ff; + if (ally_x->update_ff) { + ff_report = kmemdup(ally_x->ff_packet, sizeof(*ally_x->ff_packet), GFP_KERNEL); + ally_x->update_ff = false; + } + update_qam = ally_x->update_qam_btn; + spin_unlock_irqrestore(&ally_x->lock, flags); + + if (update_ff && ff_report) { + ff_report->ff.magnitude_left = ff_report->ff.magnitude_strong; + ff_report->ff.magnitude_right = ff_report->ff.magnitude_weak; + asus_dev_set_report(ally_x->hdev, (u8 *)ff_report, sizeof(*ff_report)); + } + kfree(ff_report); + + if (update_qam) { + /* + * The sleeps here are required to allow steam to register the button combo. + */ + usleep_range(1000, 2000); + input_report_key(ally_x->input, BTN_MODE, 1); + input_sync(ally_x->input); + + msleep(80); + input_report_key(ally_x->input, BTN_A, 1); + input_sync(ally_x->input); + + msleep(80); + input_report_key(ally_x->input, BTN_A, 0); + input_sync(ally_x->input); + + msleep(80); + input_report_key(ally_x->input, BTN_MODE, 0); + input_sync(ally_x->input); + + spin_lock_irqsave(&ally_x->lock, flags); + ally_x->update_qam_btn = false; + spin_unlock_irqrestore(&ally_x->lock, flags); + } +} + +static struct input_dev *ally_x_setup_input(struct hid_device *hdev) +{ + int ret, abs_min = 0, js_abs_max = 65535, tr_abs_max = 1023; + struct input_dev *input; + + input = ally_x_alloc_input_dev(hdev, NULL); + if (IS_ERR(input)) + return ERR_CAST(input); + + input_set_abs_params(input, ABS_X, abs_min, js_abs_max, 0, 0); + input_set_abs_params(input, ABS_Y, abs_min, js_abs_max, 0, 0); + input_set_abs_params(input, ABS_RX, abs_min, js_abs_max, 0, 0); + input_set_abs_params(input, ABS_RY, abs_min, js_abs_max, 0, 0); + input_set_abs_params(input, ABS_Z, abs_min, tr_abs_max, 0, 0); + input_set_abs_params(input, ABS_RZ, abs_min, tr_abs_max, 0, 0); + input_set_abs_params(input, ABS_HAT0X, -1, 1, 0, 0); + input_set_abs_params(input, ABS_HAT0Y, -1, 1, 0, 0); + input_set_capability(input, EV_KEY, BTN_A); + input_set_capability(input, EV_KEY, BTN_B); + input_set_capability(input, EV_KEY, BTN_X); + input_set_capability(input, EV_KEY, BTN_Y); + input_set_capability(input, EV_KEY, BTN_TL); + input_set_capability(input, EV_KEY, BTN_TR); + input_set_capability(input, EV_KEY, BTN_SELECT); + input_set_capability(input, EV_KEY, BTN_START); + input_set_capability(input, EV_KEY, BTN_MODE); + input_set_capability(input, EV_KEY, BTN_THUMBL); + input_set_capability(input, EV_KEY, BTN_THUMBR); + + input_set_capability(input, EV_KEY, KEY_PROG1); + input_set_capability(input, EV_KEY, KEY_F16); + input_set_capability(input, EV_KEY, KEY_F17); + input_set_capability(input, EV_KEY, KEY_F18); + + input_set_capability(input, EV_FF, FF_RUMBLE); + input_ff_create_memless(input, NULL, ally_x_play_effect); + + ret = input_register_device(input); + if (ret) + return ERR_PTR(ret); + + return input; +} + +static ssize_t ally_x_qam_mode_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct ally_x_device *ally_x = drvdata.ally_x; + + return sysfs_emit(buf, "%d\n", ally_x->qam_btns_steam_mode); +} + +static ssize_t ally_x_qam_mode_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct ally_x_device *ally_x = drvdata.ally_x; + bool val; + int ret; + + ret = kstrtobool(buf, &val); + if (ret < 0) + return ret; + + ally_x->qam_btns_steam_mode = val; + + return count; +} +ALLY_DEVICE_ATTR_RW(ally_x_qam_mode, qam_mode); + +static struct ally_x_device *ally_x_create(struct hid_device *hdev) +{ + uint8_t max_output_report_size; + struct ally_x_device *ally_x; + struct ff_report *report; + int ret; + + ally_x = devm_kzalloc(&hdev->dev, sizeof(*ally_x), GFP_KERNEL); + if (!ally_x) + return ERR_PTR(-ENOMEM); + + ally_x->hdev = hdev; + INIT_WORK(&ally_x->output_worker, ally_x_work); + spin_lock_init(&ally_x->lock); + ally_x->output_worker_initialized = true; + ally_x->qam_btns_steam_mode = + true; /* Always default to steam mode, it can be changed by userspace attr */ + + max_output_report_size = sizeof(struct ally_x_input_report); + report = devm_kzalloc(&hdev->dev, sizeof(*report), GFP_KERNEL); + if (!report) { + ret = -ENOMEM; + goto free_ally_x; + } + + /* None of these bytes will change for the FF command for now */ + report->report_id = 0x0D; + report->ff.enable = 0x0F; /* Enable all by default */ + report->ff.pulse_sustain_10ms = 0xFF; /* Duration */ + report->ff.pulse_release_10ms = 0x00; /* Start Delay */ + report->ff.loop_count = 0xEB; /* Loop Count */ + ally_x->ff_packet = report; + + ally_x->input = ally_x_setup_input(hdev); + if (IS_ERR(ally_x->input)) { + ret = PTR_ERR(ally_x->input); + goto free_ff_packet; + } + + if (sysfs_create_file(&hdev->dev.kobj, &dev_attr_ally_x_qam_mode.attr)) { + ret = -ENODEV; + goto unregister_input; + } + + ally_x->update_ff = true; + if (ally_x->output_worker_initialized) + schedule_work(&ally_x->output_worker); + + hid_info(hdev, "Registered Ally X controller using %s\n", + dev_name(&ally_x->input->dev)); + return ally_x; + +unregister_input: + input_unregister_device(ally_x->input); +free_ff_packet: + kfree(ally_x->ff_packet); +free_ally_x: + kfree(ally_x); + return ERR_PTR(ret); +} + +static void ally_x_remove(struct hid_device *hdev) +{ + struct ally_x_device *ally_x = drvdata.ally_x; + unsigned long flags; + + spin_lock_irqsave(&ally_x->lock, flags); + ally_x->output_worker_initialized = false; + spin_unlock_irqrestore(&ally_x->lock, flags); + cancel_work_sync(&ally_x->output_worker); + sysfs_remove_file(&hdev->dev.kobj, &dev_attr_ally_x_qam_mode.attr); +} + +/**************************************************************************************************/ +/* ROG Ally LED control */ +/**************************************************************************************************/ +static void ally_rgb_schedule_work(struct ally_rgb_dev *led) +{ + unsigned long flags; + + spin_lock_irqsave(&led->lock, flags); + if (!led->removed) + schedule_work(&led->work); + spin_unlock_irqrestore(&led->lock, flags); +} + +/* + * The RGB still has the basic 0-3 level brightness. Since the multicolour + * brightness is being used in place, set this to max + */ +static int ally_rgb_set_bright_base_max(struct hid_device *hdev) +{ + u8 buf[] = { FEATURE_KBD_LED_REPORT_ID1, 0xba, 0xc5, 0xc4, 0x02 }; + + return asus_dev_set_report(hdev, buf, sizeof(buf)); +} + +static void ally_rgb_do_work(struct work_struct *work) +{ + struct ally_rgb_dev *led = container_of(work, struct ally_rgb_dev, work); + int ret; + unsigned long flags; + + u8 buf[16] = { [0] = FEATURE_ROG_ALLY_REPORT_ID, + [1] = FEATURE_ROG_ALLY_CODE_PAGE, + [2] = xpad_cmd_set_leds, + [3] = xpad_cmd_len_leds }; + + spin_lock_irqsave(&led->lock, flags); + if (!led->update_rgb) { + spin_unlock_irqrestore(&led->lock, flags); + return; + } + + for (int i = 0; i < 4; i++) { + buf[5 + i * 3] = drvdata.led_rgb_dev->green[i]; + buf[6 + i * 3] = drvdata.led_rgb_dev->blue[i]; + buf[4 + i * 3] = drvdata.led_rgb_dev->red[i]; + } + led->update_rgb = false; + + spin_unlock_irqrestore(&led->lock, flags); + + ret = asus_dev_set_report(led->hdev, buf, sizeof(buf)); + if (ret < 0) + hid_err(led->hdev, "Ally failed to set gamepad backlight: %d\n", ret); +} + +static void ally_rgb_set(struct led_classdev *cdev, enum led_brightness brightness) +{ + struct led_classdev_mc *mc_cdev = lcdev_to_mccdev(cdev); + struct ally_rgb_dev *led = container_of(mc_cdev, struct ally_rgb_dev, led_rgb_dev); + int intensity, bright; + unsigned long flags; + + led_mc_calc_color_components(mc_cdev, brightness); + spin_lock_irqsave(&led->lock, flags); + led->update_rgb = true; + bright = mc_cdev->led_cdev.brightness; + for (int i = 0; i < 4; i++) { + intensity = mc_cdev->subled_info[i].intensity; + drvdata.led_rgb_dev->red[i] = (((intensity >> 16) & 0xFF) * bright) / 255; + drvdata.led_rgb_dev->green[i] = (((intensity >> 8) & 0xFF) * bright) / 255; + drvdata.led_rgb_dev->blue[i] = ((intensity & 0xFF) * bright) / 255; + } + spin_unlock_irqrestore(&led->lock, flags); + drvdata.led_rgb_data.initialized = true; + + ally_rgb_schedule_work(led); +} + +static int ally_rgb_set_static_from_multi(struct hid_device *hdev) +{ + u8 buf[17] = {FEATURE_KBD_LED_REPORT_ID1, 0xb3}; + int ret; + + /* + * Set single zone single colour based on the first LED of EC software mode. + * buf[2] = zone, buf[3] = mode + */ + buf[4] = drvdata.led_rgb_data.red[0]; + buf[5] = drvdata.led_rgb_data.green[0]; + buf[6] = drvdata.led_rgb_data.blue[0]; + + ret = asus_dev_set_report(hdev, buf, sizeof(buf)); + if (ret < 0) + return ret; + + ret = asus_dev_set_report(hdev, EC_MODE_LED_APPLY, sizeof(EC_MODE_LED_APPLY)); + if (ret < 0) + return ret; + + return asus_dev_set_report(hdev, EC_MODE_LED_SET, sizeof(EC_MODE_LED_SET)); +} + +/* + * Store the RGB values for restoring on resume, and set the static mode to the first LED colour +*/ +static void ally_rgb_store_settings(void) +{ + int arr_size = sizeof(drvdata.led_rgb_data.red); + + struct ally_rgb_dev *led_rgb = drvdata.led_rgb_dev; + + drvdata.led_rgb_data.brightness = led_rgb->led_rgb_dev.led_cdev.brightness; + + memcpy(drvdata.led_rgb_data.red, led_rgb->red, arr_size); + memcpy(drvdata.led_rgb_data.green, led_rgb->green, arr_size); + memcpy(drvdata.led_rgb_data.blue, led_rgb->blue, arr_size); + + ally_rgb_set_static_from_multi(led_rgb->hdev); +} + +static void ally_rgb_restore_settings(struct ally_rgb_dev *led_rgb, struct led_classdev *led_cdev, + struct mc_subled *mc_led_info) +{ + int arr_size = sizeof(drvdata.led_rgb_data.red); + + memcpy(led_rgb->red, drvdata.led_rgb_data.red, arr_size); + memcpy(led_rgb->green, drvdata.led_rgb_data.green, arr_size); + memcpy(led_rgb->blue, drvdata.led_rgb_data.blue, arr_size); + for (int i = 0; i < 4; i++) { + mc_led_info[i].intensity = (drvdata.led_rgb_data.red[i] << 16) | + (drvdata.led_rgb_data.green[i] << 8) | + drvdata.led_rgb_data.blue[i]; + } + led_cdev->brightness = drvdata.led_rgb_data.brightness; +} + +/* Set LEDs. Call after any setup. */ +static void ally_rgb_resume(void) +{ + struct ally_rgb_dev *led_rgb = drvdata.led_rgb_dev; + struct led_classdev *led_cdev; + struct mc_subled *mc_led_info; + + if (!led_rgb) + return; + + led_cdev = &led_rgb->led_rgb_dev.led_cdev; + mc_led_info = led_rgb->led_rgb_dev.subled_info; + + if (drvdata.led_rgb_data.initialized) { + ally_rgb_restore_settings(led_rgb, led_cdev, mc_led_info); + led_rgb->update_rgb = true; + ally_rgb_schedule_work(led_rgb); + ally_rgb_set_bright_base_max(led_rgb->hdev); + } +} + +static int ally_rgb_register(struct hid_device *hdev, struct ally_rgb_dev *led_rgb) +{ + struct mc_subled *mc_led_info; + struct led_classdev *led_cdev; + + mc_led_info = + devm_kmalloc_array(&hdev->dev, 12, sizeof(*mc_led_info), GFP_KERNEL | __GFP_ZERO); + if (!mc_led_info) + return -ENOMEM; + + mc_led_info[0].color_index = LED_COLOR_ID_RGB; + mc_led_info[1].color_index = LED_COLOR_ID_RGB; + mc_led_info[2].color_index = LED_COLOR_ID_RGB; + mc_led_info[3].color_index = LED_COLOR_ID_RGB; + + led_rgb->led_rgb_dev.subled_info = mc_led_info; + led_rgb->led_rgb_dev.num_colors = 4; + + led_cdev = &led_rgb->led_rgb_dev.led_cdev; + led_cdev->brightness = 128; + led_cdev->name = "ally:rgb:joystick_rings"; + led_cdev->max_brightness = 255; + led_cdev->brightness_set = ally_rgb_set; + + if (drvdata.led_rgb_data.initialized) { + ally_rgb_restore_settings(led_rgb, led_cdev, mc_led_info); + } + + return devm_led_classdev_multicolor_register(&hdev->dev, &led_rgb->led_rgb_dev); +} + +static struct ally_rgb_dev *ally_rgb_create(struct hid_device *hdev) +{ + struct ally_rgb_dev *led_rgb; + int ret; + + led_rgb = devm_kzalloc(&hdev->dev, sizeof(struct ally_rgb_dev), GFP_KERNEL); + if (!led_rgb) + return ERR_PTR(-ENOMEM); + + ret = ally_rgb_register(hdev, led_rgb); + if (ret < 0) { + cancel_work_sync(&led_rgb->work); + devm_kfree(&hdev->dev, led_rgb); + return ERR_PTR(ret); + } + + led_rgb->hdev = hdev; + led_rgb->removed = false; + + INIT_WORK(&led_rgb->work, ally_rgb_do_work); + led_rgb->output_worker_initialized = true; + spin_lock_init(&led_rgb->lock); + + ally_rgb_set_bright_base_max(hdev); + + /* Not marked as initialized unless ally_rgb_set() is called */ + if (drvdata.led_rgb_data.initialized) { + msleep(1500); + led_rgb->update_rgb = true; + ally_rgb_schedule_work(led_rgb); + } + + return led_rgb; +} + +static void ally_rgb_remove(struct hid_device *hdev) +{ + struct ally_rgb_dev *led_rgb = drvdata.led_rgb_dev; + unsigned long flags; + int ep; + + ep = get_endpoint_address(hdev); + if (ep != ROG_ALLY_CFG_INTF_IN) + return; + + if (!drvdata.led_rgb_dev || led_rgb->removed) + return; + + spin_lock_irqsave(&led_rgb->lock, flags); + led_rgb->removed = true; + led_rgb->output_worker_initialized = false; + spin_unlock_irqrestore(&led_rgb->lock, flags); + cancel_work_sync(&led_rgb->work); + devm_led_classdev_multicolor_unregister(&hdev->dev, &led_rgb->led_rgb_dev); + + hid_info(hdev, "Removed Ally RGB interface"); +} + +/**************************************************************************************************/ +/* ROG Ally driver init */ +/**************************************************************************************************/ + +static int ally_raw_event(struct hid_device *hdev, struct hid_report *report, u8 *data, + int size) +{ + struct ally_gamepad_cfg *cfg = drvdata.gamepad_cfg; + struct ally_x_device *ally_x = drvdata.ally_x; + + if (ally_x) { + if ((hdev->bus == BUS_USB && report->id == ALLY_X_INPUT_REPORT_USB && + size == ALLY_X_INPUT_REPORT_USB_SIZE) || + (data[0] == 0x5A)) { + ally_x_raw_event(ally_x, report, data, size); + } else { + return -1; + } + } + + if (cfg && !ally_x) { + input_report_key(cfg->input, KEY_PROG1, data[1] == 0x38); + input_report_key(cfg->input, KEY_F16, data[1] == 0xA6); + input_report_key(cfg->input, KEY_F17, data[1] == 0xA7); + input_report_key(cfg->input, KEY_F18, data[1] == 0xA8); + input_sync(cfg->input); + } + + return 0; +} + +static int ally_hid_init(struct hid_device *hdev) +{ + int ret; + + ret = asus_dev_set_report(hdev, EC_INIT_STRING, sizeof(EC_INIT_STRING)); + if (ret < 0) { + hid_err(hdev, "Ally failed to send init command: %d\n", ret); + return ret; + } + + ret = asus_dev_set_report(hdev, FORCE_FEEDBACK_OFF, sizeof(FORCE_FEEDBACK_OFF)); + if (ret < 0) + hid_err(hdev, "Ally failed to send init command: %d\n", ret); + + return ret; +} + +static int ally_hid_probe(struct hid_device *hdev, const struct hid_device_id *_id) +{ + struct usb_interface *intf = to_usb_interface(hdev->dev.parent); + struct usb_device *udev = interface_to_usbdev(intf); + u16 idProduct = le16_to_cpu(udev->descriptor.idProduct); + int ret, ep; + + ep = get_endpoint_address(hdev); + if (ep < 0) + return ep; + + if (ep != ROG_ALLY_CFG_INTF_IN && + ep != ROG_ALLY_X_INTF_IN) + return -ENODEV; + + ret = hid_parse(hdev); + if (ret) { + hid_err(hdev, "Parse failed\n"); + return ret; + } + + ret = hid_hw_start(hdev, HID_CONNECT_HIDRAW); + if (ret) { + hid_err(hdev, "Failed to start HID device\n"); + return ret; + } + + ret = hid_hw_open(hdev); + if (ret) { + hid_err(hdev, "Failed to open HID device\n"); + goto err_stop; + } + + /* Initialize MCU even before alloc */ + ret = ally_hid_init(hdev); + if (ret < 0) + return ret; + + drvdata.hdev = hdev; + hid_set_drvdata(hdev, &drvdata); + + /* This should almost always exist */ + if (ep == ROG_ALLY_CFG_INTF_IN) { + validate_mcu_fw_version(hdev, idProduct); + + drvdata.led_rgb_dev = ally_rgb_create(hdev); + if (IS_ERR(drvdata.led_rgb_dev)) + hid_err(hdev, "Failed to create Ally gamepad LEDs.\n"); + else + hid_info(hdev, "Created Ally RGB LED controls.\n"); + + drvdata.gamepad_cfg = ally_gamepad_cfg_create(hdev); + if (IS_ERR(drvdata.gamepad_cfg)) + hid_err(hdev, "Failed to create Ally gamepad attributes.\n"); + else + hid_info(hdev, "Created Ally gamepad attributes.\n"); + + if (IS_ERR(drvdata.led_rgb_dev) && IS_ERR(drvdata.gamepad_cfg)) + goto err_close; + } + + /* May or may not exist */ + if (ep == ROG_ALLY_X_INTF_IN) { + drvdata.ally_x = ally_x_create(hdev); + if (IS_ERR(drvdata.ally_x)) { + hid_err(hdev, "Failed to create Ally X gamepad.\n"); + drvdata.ally_x = NULL; + goto err_close; + } + hid_info(hdev, "Created Ally X controller.\n"); + + // Not required since we send this inputs ep through the gamepad input dev + if (drvdata.gamepad_cfg && drvdata.gamepad_cfg->input) { + input_unregister_device(drvdata.gamepad_cfg->input); + hid_info(hdev, "Ally X removed unrequired input dev.\n"); + } + } + + return 0; + +err_close: + hid_hw_close(hdev); +err_stop: + hid_hw_stop(hdev); + return ret; +} + +static void ally_hid_remove(struct hid_device *hdev) +{ + if (drvdata.led_rgb_dev) + ally_rgb_remove(hdev); + + if (drvdata.ally_x) + ally_x_remove(hdev); + + if (drvdata.gamepad_cfg) + ally_cfg_remove(hdev); + + hid_hw_close(hdev); + hid_hw_stop(hdev); +} + +static int ally_hid_resume(struct hid_device *hdev) +{ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; + int err; + + if (!ally_cfg) + return 0; + + err = _gamepad_apply_all(hdev, ally_cfg); + if (err) + return err; + + return 0; +} + +static int ally_hid_reset_resume(struct hid_device *hdev) +{ + int ep = get_endpoint_address(hdev); + if (ep != ROG_ALLY_CFG_INTF_IN) + return 0; + + ally_hid_init(hdev); + ally_rgb_resume(); + + return ally_hid_resume(hdev); +} + +static int ally_pm_thaw(struct device *dev) +{ + struct hid_device *hdev = to_hid_device(dev); + + return ally_hid_reset_resume(hdev); +} + +static int ally_pm_suspend(struct device *dev) +{ + if (drvdata.led_rgb_dev) { + ally_rgb_store_settings(); + } + + return 0; +} + +static const struct dev_pm_ops ally_pm_ops = { + .thaw = ally_pm_thaw, + .suspend = ally_pm_suspend, + .poweroff = ally_pm_suspend, +}; + +MODULE_DEVICE_TABLE(hid, rog_ally_devices); + +static struct hid_driver rog_ally_cfg = { .name = "asus_rog_ally", + .id_table = rog_ally_devices, + .probe = ally_hid_probe, + .remove = ally_hid_remove, + .raw_event = ally_raw_event, + /* HID is the better place for resume functions, not pm_ops */ + .resume = ally_hid_resume, + /* ALLy 1 requires this to reset device state correctly */ + .reset_resume = ally_hid_reset_resume, + .driver = { + .pm = &ally_pm_ops, + } +}; + +static int __init rog_ally_init(void) +{ + return hid_register_driver(&rog_ally_cfg); +} + +static void __exit rog_ally_exit(void) +{ + hid_unregister_driver(&rog_ally_cfg); +} + +module_init(rog_ally_init); +module_exit(rog_ally_exit); + +MODULE_IMPORT_NS("ASUS_WMI"); +MODULE_IMPORT_NS("HID_ASUS"); +MODULE_AUTHOR("Luke D. Jones"); +MODULE_DESCRIPTION("HID Driver for ASUS ROG Ally gamepad configuration."); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-asus-ally.h b/drivers/hid/hid-asus-ally.h new file mode 100644 index 000000000000..c83817589082 --- /dev/null +++ b/drivers/hid/hid-asus-ally.h @@ -0,0 +1,398 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later + * + * HID driver for Asus ROG laptops and Ally + * + * Copyright (c) 2023 Luke Jones + */ + +#include +#include + +/* + * the xpad_mode is used inside the mode setting packet and is used + * for indexing (xpad_mode - 1) + */ +enum xpad_mode { + xpad_mode_game = 0x01, + xpad_mode_wasd = 0x02, + xpad_mode_mouse = 0x03, +}; + +/* the xpad_cmd determines which feature is set or queried */ +enum xpad_cmd { + xpad_cmd_set_mode = 0x01, + xpad_cmd_set_mapping = 0x02, + xpad_cmd_set_js_dz = 0x04, /* deadzones */ + xpad_cmd_set_tr_dz = 0x05, /* deadzones */ + xpad_cmd_set_vibe_intensity = 0x06, + xpad_cmd_set_leds = 0x08, + xpad_cmd_check_ready = 0x0A, + xpad_cmd_set_turbo = 0x0F, + xpad_cmd_set_response_curve = 0x13, + xpad_cmd_set_adz = 0x18, +}; + +/* the xpad_cmd determines which feature is set or queried */ +enum xpad_cmd_len { + xpad_cmd_len_mode = 0x01, + xpad_cmd_len_mapping = 0x2c, + xpad_cmd_len_deadzone = 0x04, + xpad_cmd_len_vibe_intensity = 0x02, + xpad_cmd_len_leds = 0x0C, + xpad_cmd_len_turbo = 0x20, + xpad_cmd_len_response_curve = 0x09, + xpad_cmd_len_adz = 0x02, +}; + +/* Values correspond to the actual HID byte value required */ +enum btn_pair_index { + btn_pair_dpad_u_d = 0x01, + btn_pair_dpad_l_r = 0x02, + btn_pair_ls_rs = 0x03, + btn_pair_lb_rb = 0x04, + btn_pair_a_b = 0x05, + btn_pair_x_y = 0x06, + btn_pair_view_menu = 0x07, + btn_pair_m1_m2 = 0x08, + btn_pair_lt_rt = 0x09, +}; + +#define BTN_PAD_A 0x0101000000000000 +#define BTN_PAD_B 0x0102000000000000 +#define BTN_PAD_X 0x0103000000000000 +#define BTN_PAD_Y 0x0104000000000000 +#define BTN_PAD_LB 0x0105000000000000 +#define BTN_PAD_RB 0x0106000000000000 +#define BTN_PAD_LS 0x0107000000000000 +#define BTN_PAD_RS 0x0108000000000000 +#define BTN_PAD_DPAD_UP 0x0109000000000000 +#define BTN_PAD_DPAD_DOWN 0x010A000000000000 +#define BTN_PAD_DPAD_LEFT 0x010B000000000000 +#define BTN_PAD_DPAD_RIGHT 0x010C000000000000 +#define BTN_PAD_LT 0x010D000000000000 +#define BTN_PAD_RT 0x010E000000000000 +#define BTN_PAD_VIEW 0x0111000000000000 +#define BTN_PAD_MENU 0x0112000000000000 +#define BTN_PAD_XBOX 0x0113000000000000 + +#define BTN_KB_M2 0x02008E0000000000 +#define BTN_KB_M1 0x02008F0000000000 +#define BTN_KB_ESC 0x0200760000000000 +#define BTN_KB_F1 0x0200500000000000 +#define BTN_KB_F2 0x0200600000000000 +#define BTN_KB_F3 0x0200400000000000 +#define BTN_KB_F4 0x02000C0000000000 +#define BTN_KB_F5 0x0200030000000000 +#define BTN_KB_F6 0x02000B0000000000 +#define BTN_KB_F7 0x0200800000000000 +#define BTN_KB_F8 0x02000A0000000000 +#define BTN_KB_F9 0x0200010000000000 +#define BTN_KB_F10 0x0200090000000000 +#define BTN_KB_F11 0x0200780000000000 +#define BTN_KB_F12 0x0200070000000000 +#define BTN_KB_F14 0x0200180000000000 +#define BTN_KB_F15 0x0200100000000000 +#define BTN_KB_BACKTICK 0x02000E0000000000 +#define BTN_KB_1 0x0200160000000000 +#define BTN_KB_2 0x02001E0000000000 +#define BTN_KB_3 0x0200260000000000 +#define BTN_KB_4 0x0200250000000000 +#define BTN_KB_5 0x02002E0000000000 +#define BTN_KB_6 0x0200360000000000 +#define BTN_KB_7 0x02003D0000000000 +#define BTN_KB_8 0x02003E0000000000 +#define BTN_KB_9 0x0200460000000000 +#define BTN_KB_0 0x0200450000000000 +#define BTN_KB_HYPHEN 0x02004E0000000000 +#define BTN_KB_EQUALS 0x0200550000000000 +#define BTN_KB_BACKSPACE 0x0200660000000000 +#define BTN_KB_TAB 0x02000D0000000000 +#define BTN_KB_Q 0x0200150000000000 +#define BTN_KB_W 0x02001D0000000000 +#define BTN_KB_E 0x0200240000000000 +#define BTN_KB_R 0x02002D0000000000 +#define BTN_KB_T 0x02002C0000000000 +#define BTN_KB_Y 0x0200350000000000 +#define BTN_KB_U 0x02003C0000000000 +#define BTN_KB_O 0x0200440000000000 +#define BTN_KB_P 0x02004D0000000000 +#define BTN_KB_LBRACKET 0x0200540000000000 +#define BTN_KB_RBRACKET 0x02005B0000000000 +#define BTN_KB_BACKSLASH 0x02005D0000000000 +#define BTN_KB_CAPS 0x0200580000000000 +#define BTN_KB_A 0x02001C0000000000 +#define BTN_KB_S 0x02001B0000000000 +#define BTN_KB_D 0x0200230000000000 +#define BTN_KB_F 0x02002B0000000000 +#define BTN_KB_G 0x0200340000000000 +#define BTN_KB_H 0x0200330000000000 +#define BTN_KB_J 0x02003B0000000000 +#define BTN_KB_K 0x0200420000000000 +#define BTN_KB_L 0x02004B0000000000 +#define BTN_KB_SEMI 0x02004C0000000000 +#define BTN_KB_QUOTE 0x0200520000000000 +#define BTN_KB_RET 0x02005A0000000000 +#define BTN_KB_LSHIFT 0x0200880000000000 +#define BTN_KB_Z 0x02001A0000000000 +#define BTN_KB_X 0x0200220000000000 +#define BTN_KB_C 0x0200210000000000 +#define BTN_KB_V 0x02002A0000000000 +#define BTN_KB_B 0x0200320000000000 +#define BTN_KB_N 0x0200310000000000 +#define BTN_KB_M 0x02003A0000000000 +#define BTN_KB_COMMA 0x0200410000000000 +#define BTN_KB_PERIOD 0x0200490000000000 +#define BTN_KB_RSHIFT 0x0200890000000000 +#define BTN_KB_LCTL 0x02008C0000000000 +#define BTN_KB_META 0x0200820000000000 +#define BTN_KB_LALT 0x02008A0000000000 +#define BTN_KB_SPACE 0x0200290000000000 +#define BTN_KB_RALT 0x02008B0000000000 +#define BTN_KB_MENU 0x0200840000000000 +#define BTN_KB_RCTL 0x02008D0000000000 +#define BTN_KB_PRNTSCN 0x0200C30000000000 +#define BTN_KB_SCRLCK 0x02007E0000000000 +#define BTN_KB_PAUSE 0x0200910000000000 +#define BTN_KB_INS 0x0200C20000000000 +#define BTN_KB_HOME 0x0200940000000000 +#define BTN_KB_PGUP 0x0200960000000000 +#define BTN_KB_DEL 0x0200C00000000000 +#define BTN_KB_END 0x0200950000000000 +#define BTN_KB_PGDWN 0x0200970000000000 +#define BTN_KB_UP_ARROW 0x0200980000000000 +#define BTN_KB_DOWN_ARROW 0x0200990000000000 +#define BTN_KB_LEFT_ARROW 0x0200910000000000 +#define BTN_KB_RIGHT_ARROW 0x02009B0000000000 + +#define BTN_NUMPAD_LOCK 0x0200770000000000 +#define BTN_NUMPAD_FWDSLASH 0x0200900000000000 +#define BTN_NUMPAD_ASTERISK 0x02007C0000000000 +#define BTN_NUMPAD_HYPHEN 0x02007B0000000000 +#define BTN_NUMPAD_0 0x0200700000000000 +#define BTN_NUMPAD_1 0x0200690000000000 +#define BTN_NUMPAD_2 0x0200720000000000 +#define BTN_NUMPAD_3 0x02007A0000000000 +#define BTN_NUMPAD_4 0x02006B0000000000 +#define BTN_NUMPAD_5 0x0200730000000000 +#define BTN_NUMPAD_6 0x0200740000000000 +#define BTN_NUMPAD_7 0x02006C0000000000 +#define BTN_NUMPAD_8 0x0200750000000000 +#define BTN_NUMPAD_9 0x02007D0000000000 +#define BTN_NUMPAD_PLUS 0x0200790000000000 +#define BTN_NUMPAD_ENTER 0x0200810000000000 +#define BTN_NUMPAD_PERIOD 0x0200710000000000 + +#define BTN_MOUSE_LCLICK 0x0300000001000000 +#define BTN_MOUSE_RCLICK 0x0300000002000000 +#define BTN_MOUSE_MCLICK 0x0300000003000000 +#define BTN_MOUSE_WHEEL_UP 0x0300000004000000 +#define BTN_MOUSE_WHEEL_DOWN 0x0300000005000000 + +#define BTN_MEDIA_SCREENSHOT 0x0500001600000000 +#define BTN_MEDIA_SHOW_KEYBOARD 0x0500001900000000 +#define BTN_MEDIA_SHOW_DESKTOP 0x0500001C00000000 +#define BTN_MEDIA_START_RECORDING 0x0500001E00000000 +#define BTN_MEDIA_MIC_OFF 0x0500000100000000 +#define BTN_MEDIA_VOL_DOWN 0x0500000200000000 +#define BTN_MEDIA_VOL_UP 0x0500000300000000 + +#define ALLY_DEVICE_ATTR_WO(_name, _sysfs_name) \ + struct device_attribute dev_attr_##_name = \ + __ATTR(_sysfs_name, 0200, NULL, _name##_store) + +/* required so we can have nested attributes with same name but different functions */ +#define ALLY_DEVICE_ATTR_RW(_name, _sysfs_name) \ + struct device_attribute dev_attr_##_name = \ + __ATTR(_sysfs_name, 0644, _name##_show, _name##_store) + +#define ALLY_DEVICE_ATTR_RO(_name, _sysfs_name) \ + struct device_attribute dev_attr_##_name = \ + __ATTR(_sysfs_name, 0444, _name##_show, NULL) + +/* button specific macros */ +#define ALLY_BTN_SHOW(_fname, _btn_name, _secondary) \ + static ssize_t _fname##_show(struct device *dev, \ + struct device_attribute *attr, char *buf) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct btn_data *btn; \ + const char* name; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + btn = &ally_cfg->key_mapping[ally_cfg->mode - 1]._btn_name; \ + name = btn_to_name(_secondary ? btn->macro : btn->button); \ + return sysfs_emit(buf, "%s\n", name); \ + } + +#define ALLY_BTN_STORE(_fname, _btn_name, _secondary) \ + static ssize_t _fname##_store(struct device *dev, \ + struct device_attribute *attr, \ + const char *buf, size_t count) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct btn_data *btn; \ + u64 code; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + btn = &ally_cfg->key_mapping[ally_cfg->mode - 1]._btn_name; \ + code = name_to_btn(buf); \ + if (_secondary) \ + btn->macro = code; \ + else \ + btn->button = code; \ + return count; \ + } + +#define ALLY_TURBO_SHOW(_fname, _btn_name) \ + static ssize_t _fname##_show(struct device *dev, \ + struct device_attribute *attr, char *buf) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct btn_data *btn; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + btn = &ally_cfg->key_mapping[ally_cfg->mode - 1]._btn_name; \ + return sysfs_emit(buf, "%d\n", btn->turbo); \ + } + +#define ALLY_TURBO_STORE(_fname, _btn_name) \ + static ssize_t _fname##_store(struct device *dev, \ + struct device_attribute *attr, \ + const char *buf, size_t count) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct btn_data *btn; \ + bool turbo; \ + int ret; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + btn = &ally_cfg->key_mapping[ally_cfg->mode - 1]._btn_name; \ + ret = kstrtobool(buf, &turbo); \ + if (ret) \ + return ret; \ + btn->turbo = turbo; \ + return count; \ + } + +#define ALLY_DEADZONE_SHOW(_fname, _axis_name) \ + static ssize_t _fname##_show(struct device *dev, \ + struct device_attribute *attr, char *buf) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct deadzone *dz; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + dz = &ally_cfg->_axis_name; \ + return sysfs_emit(buf, "%d %d\n", dz->inner, dz->outer); \ + } + +#define ALLY_DEADZONE_STORE(_fname, _axis_name) \ + static ssize_t _fname##_store(struct device *dev, \ + struct device_attribute *attr, \ + const char *buf, size_t count) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + struct hid_device *hdev = to_hid_device(dev); \ + u32 inner, outer; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + if (sscanf(buf, "%d %d", &inner, &outer) != 2) \ + return -EINVAL; \ + if (inner > 64 || outer > 64 || inner > outer) \ + return -EINVAL; \ + ally_cfg->_axis_name.inner = inner; \ + ally_cfg->_axis_name.outer = outer; \ + _gamepad_apply_deadzones(hdev, ally_cfg); \ + return count; \ + } + +#define ALLY_DEADZONES(_fname, _mname) \ + ALLY_DEADZONE_SHOW(_fname##_deadzone, _mname); \ + ALLY_DEADZONE_STORE(_fname##_deadzone, _mname); \ + ALLY_DEVICE_ATTR_RW(_fname##_deadzone, deadzone) + +/* response curve macros */ +#define ALLY_RESP_CURVE_SHOW(_fname, _mname) \ +static ssize_t _fname##_show(struct device *dev, \ + struct device_attribute *attr, \ + char *buf) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + return sysfs_emit(buf, "%d\n", ally_cfg->ls_rc._mname); \ + } + +#define ALLY_RESP_CURVE_STORE(_fname, _mname) \ +static ssize_t _fname##_store(struct device *dev, \ + struct device_attribute *attr, \ + const char *buf, size_t count) \ + { \ + struct ally_gamepad_cfg *ally_cfg = drvdata.gamepad_cfg; \ + int ret, val; \ + if (!drvdata.gamepad_cfg) \ + return -ENODEV; \ + ret = kstrtoint(buf, 0, &val); \ + if (ret) \ + return ret; \ + if (val < 0 || val > 100) \ + return -EINVAL; \ + ally_cfg->ls_rc._mname = val; \ + return count; \ + } + +/* _point_n must start at 1 */ +#define ALLY_JS_RC_POINT(_fname, _mname, _num) \ + ALLY_RESP_CURVE_SHOW(_fname##_##_mname##_##_num, _mname##_pct_##_num); \ + ALLY_RESP_CURVE_STORE(_fname##_##_mname##_##_num, _mname##_pct_##_num); \ + ALLY_DEVICE_ATTR_RW(_fname##_##_mname##_##_num, curve_##_mname##_pct_##_num) + +#define ALLY_BTN_ATTRS_GROUP(_name, _fname) \ + static struct attribute *_fname##_attrs[] = { \ + &dev_attr_##_fname.attr, \ + &dev_attr_##_fname##_macro.attr, \ + }; \ + static const struct attribute_group _fname##_attr_group = { \ + .name = __stringify(_name), \ + .attrs = _fname##_attrs, \ + } + +#define _ALLY_BTN_REMAP(_fname, _btn_name) \ + ALLY_BTN_SHOW(btn_mapping_##_fname##_remap, _btn_name, false); \ + ALLY_BTN_STORE(btn_mapping_##_fname##_remap, _btn_name, false); \ + ALLY_DEVICE_ATTR_RW(btn_mapping_##_fname##_remap, remap); + +#define _ALLY_BTN_MACRO(_fname, _btn_name) \ + ALLY_BTN_SHOW(btn_mapping_##_fname##_macro, _btn_name, true); \ + ALLY_BTN_STORE(btn_mapping_##_fname##_macro, _btn_name, true); \ + ALLY_DEVICE_ATTR_RW(btn_mapping_##_fname##_macro, macro_remap); + +#define ALLY_BTN_MAPPING(_fname, _btn_name) \ + _ALLY_BTN_REMAP(_fname, _btn_name) \ + _ALLY_BTN_MACRO(_fname, _btn_name) \ + static struct attribute *_fname##_attrs[] = { \ + &dev_attr_btn_mapping_##_fname##_remap.attr, \ + &dev_attr_btn_mapping_##_fname##_macro.attr, \ + NULL, \ + }; \ + static const struct attribute_group btn_mapping_##_fname##_attr_group = { \ + .name = __stringify(btn_##_fname), \ + .attrs = _fname##_attrs, \ + } + +#define ALLY_TURBO_BTN_MAPPING(_fname, _btn_name) \ + _ALLY_BTN_REMAP(_fname, _btn_name) \ + _ALLY_BTN_MACRO(_fname, _btn_name) \ + ALLY_TURBO_SHOW(btn_mapping_##_fname##_turbo, _btn_name); \ + ALLY_TURBO_STORE(btn_mapping_##_fname##_turbo, _btn_name); \ + ALLY_DEVICE_ATTR_RW(btn_mapping_##_fname##_turbo, turbo); \ + static struct attribute *_fname##_turbo_attrs[] = { \ + &dev_attr_btn_mapping_##_fname##_remap.attr, \ + &dev_attr_btn_mapping_##_fname##_macro.attr, \ + &dev_attr_btn_mapping_##_fname##_turbo.attr, \ + NULL, \ + }; \ + static const struct attribute_group btn_mapping_##_fname##_attr_group = { \ + .name = __stringify(btn_##_fname), \ + .attrs = _fname##_turbo_attrs, \ + } diff --git a/drivers/hid/hid-asus.c b/drivers/hid/hid-asus.c index 46e3e42f9eb5..1a2460922608 100644 --- a/drivers/hid/hid-asus.c +++ b/drivers/hid/hid-asus.c @@ -23,6 +23,7 @@ /* */ +#include "linux/export.h" #include #include #include @@ -33,6 +34,7 @@ #include #include "hid-ids.h" +#include "hid-asus.h" MODULE_AUTHOR("Yusuke Fujimaki "); MODULE_AUTHOR("Brendan McGrath "); @@ -52,6 +54,10 @@ MODULE_DESCRIPTION("Asus HID Keyboard and TouchPad"); #define FEATURE_KBD_LED_REPORT_ID1 0x5d #define FEATURE_KBD_LED_REPORT_ID2 0x5e +#define ROG_ALLY_REPORT_SIZE 64 +#define ROG_ALLY_X_MIN_MCU 313 +#define ROG_ALLY_MIN_MCU 319 + #define SUPPORT_KBD_BACKLIGHT BIT(0) #define MAX_TOUCH_MAJOR 8 @@ -84,6 +90,7 @@ MODULE_DESCRIPTION("Asus HID Keyboard and TouchPad"); #define QUIRK_MEDION_E1239T BIT(10) #define QUIRK_ROG_NKEY_KEYBOARD BIT(11) #define QUIRK_ROG_CLAYMORE_II_KEYBOARD BIT(12) +#define QUIRK_ROG_ALLY_XPAD BIT(13) #define I2C_KEYBOARD_QUIRKS (QUIRK_FIX_NOTEBOOK_REPORT | \ QUIRK_NO_INIT_REPORTS | \ @@ -534,6 +541,98 @@ static bool asus_kbd_wmi_led_control_present(struct hid_device *hdev) return !!(value & ASUS_WMI_DSTS_PRESENCE_BIT); } +/* + * We don't care about any other part of the string except the version section. + * Example strings: FGA80100.RC72LA.312_T01, FGA80100.RC71LS.318_T01 + * The bytes "5a 05 03 31 00 1a 13" and possibly more come before the version + * string, and there may be additional bytes after the version string such as + * "75 00 74 00 65 00" or a postfix such as "_T01" + */ +static int mcu_parse_version_string(const u8 *response, size_t response_size) +{ + const u8 *end = response + response_size; + const u8 *p = response; + int dots, err, version; + char buf[4]; + + dots = 0; + while (p < end && dots < 2) { + if (*p++ == '.') + dots++; + } + + if (dots != 2 || p >= end || (p + 3) >= end) + return -EINVAL; + + memcpy(buf, p, 3); + buf[3] = '\0'; + + err = kstrtoint(buf, 10, &version); + if (err || version < 0) + return -EINVAL; + + return version; +} + +static int mcu_request_version(struct hid_device *hdev) +{ + u8 *response __free(kfree) = kzalloc(ROG_ALLY_REPORT_SIZE, GFP_KERNEL); + const u8 request[] = { 0x5a, 0x05, 0x03, 0x31, 0x00, 0x20 }; + int ret; + + if (!response) + return -ENOMEM; + + ret = asus_kbd_set_report(hdev, request, sizeof(request)); + if (ret < 0) + return ret; + + ret = hid_hw_raw_request(hdev, FEATURE_REPORT_ID, response, + ROG_ALLY_REPORT_SIZE, HID_FEATURE_REPORT, + HID_REQ_GET_REPORT); + if (ret < 0) + return ret; + + ret = mcu_parse_version_string(response, ROG_ALLY_REPORT_SIZE); + if (ret < 0) { + pr_err("Failed to parse MCU version: %d\n", ret); + print_hex_dump(KERN_ERR, "MCU: ", DUMP_PREFIX_NONE, + 16, 1, response, ROG_ALLY_REPORT_SIZE, false); + } + + return ret; +} + +void validate_mcu_fw_version(struct hid_device *hdev, int idProduct) +{ + int min_version, version; + + version = mcu_request_version(hdev); + if (version < 0) + return; + + switch (idProduct) { + case USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY: + min_version = ROG_ALLY_MIN_MCU; + break; + case USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY_X: + min_version = ROG_ALLY_X_MIN_MCU; + break; + default: + min_version = 0; + } + + if (version < min_version) { + hid_warn(hdev, + "The MCU firmware version must be %d or greater to avoid issues with suspend.\n", + min_version); + } else { + set_ally_mcu_hack(ASUS_WMI_ALLY_MCU_HACK_DISABLED); + set_ally_mcu_powersave(true); + } +} +EXPORT_SYMBOL_NS(validate_mcu_fw_version, "HID_ASUS"); + static int asus_kbd_register_leds(struct hid_device *hdev) { struct asus_drvdata *drvdata = hid_get_drvdata(hdev); @@ -560,6 +659,16 @@ static int asus_kbd_register_leds(struct hid_device *hdev) if (ret < 0) return ret; } + + #if !IS_REACHABLE(CONFIG_HID_ASUS_ALLY) + if (drvdata->quirks & QUIRK_ROG_ALLY_XPAD) { + struct usb_interface *intf = to_usb_interface(hdev->dev.parent); + struct usb_device *udev = interface_to_usbdev(intf); + validate_mcu_fw_version(hdev, + le16_to_cpu(udev->descriptor.idProduct)); + } + #endif /* !IS_REACHABLE(CONFIG_HID_ASUS_ALLY) */ + } else { /* Initialize keyboard */ ret = asus_kbd_init(hdev, FEATURE_KBD_REPORT_ID); @@ -1016,8 +1125,10 @@ static int __maybe_unused asus_reset_resume(struct hid_device *hdev) static int asus_probe(struct hid_device *hdev, const struct hid_device_id *id) { - int ret; struct asus_drvdata *drvdata; + struct usb_host_endpoint *ep; + struct usb_interface *intf; + int ret; drvdata = devm_kzalloc(&hdev->dev, sizeof(*drvdata), GFP_KERNEL); if (drvdata == NULL) { @@ -1029,6 +1140,18 @@ static int asus_probe(struct hid_device *hdev, const struct hid_device_id *id) drvdata->quirks = id->driver_data; + /* Ignore these endpoints as they are used by hid-asus-ally */ + #if IS_REACHABLE(CONFIG_HID_ASUS_ALLY) + if (drvdata->quirks & QUIRK_ROG_ALLY_XPAD) { + intf = to_usb_interface(hdev->dev.parent); + ep = intf->cur_altsetting->endpoint; + if (ep->desc.bEndpointAddress == ROG_ALLY_X_INTF_IN || + ep->desc.bEndpointAddress == ROG_ALLY_CFG_INTF_IN || + ep->desc.bEndpointAddress == ROG_ALLY_CFG_INTF_OUT) + return -ENODEV; + } + #endif /* IS_REACHABLE(CONFIG_HID_ASUS_ALLY) */ + /* * T90CHI's keyboard dock returns same ID values as T100CHI's dock. * Thus, identify T90CHI dock with product name string. @@ -1280,10 +1403,10 @@ static const struct hid_device_id asus_devices[] = { QUIRK_USE_KBD_BACKLIGHT | QUIRK_ROG_NKEY_KEYBOARD }, { HID_USB_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY), - QUIRK_USE_KBD_BACKLIGHT | QUIRK_ROG_NKEY_KEYBOARD }, + QUIRK_USE_KBD_BACKLIGHT | QUIRK_ROG_NKEY_KEYBOARD | QUIRK_ROG_ALLY_XPAD}, { HID_USB_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY_X), - QUIRK_USE_KBD_BACKLIGHT | QUIRK_ROG_NKEY_KEYBOARD }, + QUIRK_USE_KBD_BACKLIGHT | QUIRK_ROG_NKEY_KEYBOARD | QUIRK_ROG_ALLY_XPAD }, { HID_USB_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_ROG_CLAYMORE_II_KEYBOARD), QUIRK_ROG_CLAYMORE_II_KEYBOARD }, @@ -1327,4 +1450,5 @@ static struct hid_driver asus_driver = { }; module_hid_driver(asus_driver); +MODULE_IMPORT_NS("ASUS_WMI"); MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-asus.h b/drivers/hid/hid-asus.h new file mode 100644 index 000000000000..f67dd5a3a1bc --- /dev/null +++ b/drivers/hid/hid-asus.h @@ -0,0 +1,13 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __HID_ASUS_H +#define __HID_ASUS_H + +#include + +#define ROG_ALLY_CFG_INTF_IN 0x83 +#define ROG_ALLY_CFG_INTF_OUT 0x04 +#define ROG_ALLY_X_INTF_IN 0x87 + +void validate_mcu_fw_version(struct hid_device *hdev, int idProduct); + +#endif /* __HID_ASUS_H */ diff --git a/drivers/hid/hid-ids.h b/drivers/hid/hid-ids.h index 288a2b864cc4..50cd02b049fc 100644 --- a/drivers/hid/hid-ids.h +++ b/drivers/hid/hid-ids.h @@ -217,6 +217,7 @@ #define USB_DEVICE_ID_ASUSTEK_ROG_NKEY_KEYBOARD2 0x19b6 #define USB_DEVICE_ID_ASUSTEK_ROG_NKEY_KEYBOARD3 0x1a30 #define USB_DEVICE_ID_ASUSTEK_ROG_Z13_LIGHTBAR 0x18c6 +#define USB_DEVICE_ID_ASUSTEK_ROG_RAIKIRI_PAD 0x1abb #define USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY 0x1abe #define USB_DEVICE_ID_ASUSTEK_ROG_NKEY_ALLY_X 0x1b4c #define USB_DEVICE_ID_ASUSTEK_ROG_CLAYMORE_II_KEYBOARD 0x196b diff --git a/drivers/platform/x86/Kconfig b/drivers/platform/x86/Kconfig index 0258dd879d64..7edab99d3ae9 100644 --- a/drivers/platform/x86/Kconfig +++ b/drivers/platform/x86/Kconfig @@ -267,6 +267,18 @@ config ASUS_WIRELESS If you choose to compile this driver as a module the module will be called asus-wireless. +config ASUS_ARMOURY + tristate "ASUS Armoury driver" + depends on ASUS_WMI + select FW_ATTR_CLASS + help + Say Y here if you have a WMI aware Asus machine and would like to use the + firmware_attributes API to control various settings typically exposed in + the ASUS Armoury Crate application available on Windows. + + To compile this driver as a module, choose M here: the module will + be called asus-armoury. + config ASUS_WMI tristate "ASUS WMI Driver" depends on ACPI_WMI @@ -289,6 +301,17 @@ config ASUS_WMI To compile this driver as a module, choose M here: the module will be called asus-wmi. +config ASUS_WMI_DEPRECATED_ATTRS + bool "BIOS option support in WMI platform (DEPRECATED)" + depends on ASUS_WMI + default y + help + Say Y to expose the configurable BIOS options through the asus-wmi + driver. + + This can be used with or without the asus-armoury driver which + has the same attributes, but more, and better features. + config ASUS_NB_WMI tristate "Asus Notebook WMI Driver" depends on ASUS_WMI diff --git a/drivers/platform/x86/Makefile b/drivers/platform/x86/Makefile index e1b142947067..fe3e7e7dede8 100644 --- a/drivers/platform/x86/Makefile +++ b/drivers/platform/x86/Makefile @@ -32,6 +32,7 @@ obj-$(CONFIG_APPLE_GMUX) += apple-gmux.o # ASUS obj-$(CONFIG_ASUS_LAPTOP) += asus-laptop.o obj-$(CONFIG_ASUS_WIRELESS) += asus-wireless.o +obj-$(CONFIG_ASUS_ARMOURY) += asus-armoury.o obj-$(CONFIG_ASUS_WMI) += asus-wmi.o obj-$(CONFIG_ASUS_NB_WMI) += asus-nb-wmi.o obj-$(CONFIG_ASUS_TF103C_DOCK) += asus-tf103c-dock.o diff --git a/drivers/platform/x86/asus-armoury.c b/drivers/platform/x86/asus-armoury.c new file mode 100644 index 000000000000..84abc92bd365 --- /dev/null +++ b/drivers/platform/x86/asus-armoury.c @@ -0,0 +1,1202 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Asus Armoury (WMI) attributes driver. This driver uses the fw_attributes + * class to expose the various WMI functions that many gaming and some + * non-gaming ASUS laptops have available. + * These typically don't fit anywhere else in the sysfs such as under LED class, + * hwmon or other, and are set in Windows using the ASUS Armoury Crate tool. + * + * Copyright(C) 2024 Luke Jones + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "asus-armoury.h" +#include "firmware_attributes_class.h" + +#define ASUS_NB_WMI_EVENT_GUID "0B3CBB35-E3C2-45ED-91C2-4C5A6D195D1C" + +#define ASUS_MINI_LED_MODE_MASK 0x03 +/* Standard modes for devices with only on/off */ +#define ASUS_MINI_LED_OFF 0x00 +#define ASUS_MINI_LED_ON 0x01 +/* Like "on" but the effect is more vibrant or brighter */ +#define ASUS_MINI_LED_STRONG_MODE 0x02 +/* New modes for devices with 3 mini-led mode types */ +#define ASUS_MINI_LED_2024_WEAK 0x00 +#define ASUS_MINI_LED_2024_STRONG 0x01 +#define ASUS_MINI_LED_2024_OFF 0x02 + +/* Power tunable attribute name defines */ +#define ATTR_PPT_PL1_SPL "ppt_pl1_spl" +#define ATTR_PPT_PL2_SPPT "ppt_pl2_sppt" +#define ATTR_PPT_PL3_FPPT "ppt_pl3_fppt" +#define ATTR_PPT_APU_SPPT "ppt_apu_sppt" +#define ATTR_PPT_PLATFORM_SPPT "ppt_platform_sppt" +#define ATTR_NV_DYNAMIC_BOOST "nv_dynamic_boost" +#define ATTR_NV_TEMP_TARGET "nv_temp_target" +#define ATTR_NV_BASE_TGP "nv_base_tgp" +#define ATTR_NV_TGP "nv_tgp" + +#define ASUS_POWER_CORE_MASK GENMASK(15, 8) +#define ASUS_PERF_CORE_MASK GENMASK(7, 0) + +enum cpu_core_type { + CPU_CORE_PERF = 0, + CPU_CORE_POWER, +}; + +enum cpu_core_value { + CPU_CORE_DEFAULT = 0, + CPU_CORE_MIN, + CPU_CORE_MAX, + CPU_CORE_CURRENT, +}; + +#define CPU_PERF_CORE_COUNT_MIN 4 +#define CPU_POWR_CORE_COUNT_MIN 0 + +/* Tunables provided by ASUS for gaming laptops */ +struct cpu_cores { + u32 cur_perf_cores; + u32 min_perf_cores; + u32 max_perf_cores; + u32 cur_power_cores; + u32 min_power_cores; + u32 max_power_cores; +}; + +struct rog_tunables { + const struct power_limits *power_limits; + u32 ppt_pl1_spl; // cpu + u32 ppt_pl2_sppt; // cpu + u32 ppt_pl3_fppt; // cpu + u32 ppt_apu_sppt; // plat + u32 ppt_platform_sppt; // plat + + u32 nv_dynamic_boost; + u32 nv_temp_target; + u32 nv_tgp; +}; + +static struct asus_armoury_priv { + struct device *fw_attr_dev; + struct kset *fw_attr_kset; + + struct cpu_cores *cpu_cores; + /* Index 0 for DC, 1 for AC */ + struct rog_tunables *rog_tunables[2]; + u32 mini_led_dev_id; + u32 gpu_mux_dev_id; + /* + * Mutex to prevent big/little core count changes writing to same + * endpoint at the same time. Must lock during attr store. + */ + struct mutex cpu_core_mutex; +} asus_armoury = { + .cpu_core_mutex = __MUTEX_INITIALIZER(asus_armoury.cpu_core_mutex) +}; + +struct fw_attrs_group { + bool pending_reboot; +}; + +static struct fw_attrs_group fw_attrs = { + .pending_reboot = false, +}; + +struct asus_attr_group { + const struct attribute_group *attr_group; + u32 wmi_devid; +}; + +static bool asus_wmi_is_present(u32 dev_id) +{ + u32 retval; + int status; + + status = asus_wmi_evaluate_method(ASUS_WMI_METHODID_DSTS, dev_id, 0, &retval); + pr_debug("%s called (0x%08x), retval: 0x%08x\n", __func__, dev_id, retval); + + return status == 0 && (retval & ASUS_WMI_DSTS_PRESENCE_BIT); +} + +static void asus_set_reboot_and_signal_event(void) +{ + fw_attrs.pending_reboot = true; + kobject_uevent(&asus_armoury.fw_attr_dev->kobj, KOBJ_CHANGE); +} + +static ssize_t pending_reboot_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf) +{ + return sysfs_emit(buf, "%d\n", fw_attrs.pending_reboot); +} + +static struct kobj_attribute pending_reboot = __ATTR_RO(pending_reboot); + +static bool asus_bios_requires_reboot(struct kobj_attribute *attr) +{ + return !strcmp(attr->attr.name, "gpu_mux_mode") || + !strcmp(attr->attr.name, "cores_performance") || + !strcmp(attr->attr.name, "cores_efficiency") || + !strcmp(attr->attr.name, "panel_hd_mode"); +} + +static int armoury_wmi_set_devstate(struct kobj_attribute *attr, u32 value, u32 wmi_dev) +{ + u32 result; + int err; + + err = asus_wmi_set_devstate(wmi_dev, value, &result); + if (err) { + pr_err("Failed to set %s: %d\n", attr->attr.name, err); + return err; + } + /* + * !1 is usually considered a fail by ASUS, but some WMI methods do use > 1 + * to return a status code or similar. + */ + if (result < 1) { + pr_err("Failed to set %s: (result): 0x%x\n", attr->attr.name, result); + return -EIO; + } + + return 0; +} + +/** + * attr_int_store() - Send an int to wmi method, checks if within min/max exclusive. + * @kobj: Pointer to the driver object. + * @attr: Pointer to the attribute calling this function. + * @buf: The buffer to read from, this is parsed to `int` type. + * @count: Required by sysfs attribute macros, pass in from the callee attr. + * @min: Minimum accepted value. Below this returns -EINVAL. + * @max: Maximum accepted value. Above this returns -EINVAL. + * @store_value: Pointer to where the parsed value should be stored. + * @wmi_dev: The WMI function ID to use. + * + * This function is intended to be generic so it can be called from any "_store" + * attribute which works only with integers. The integer to be sent to the WMI method + * is range checked and an error returned if out of range. + * + * If the value is valid and WMI is success, then the sysfs attribute is notified + * and if asus_bios_requires_reboot() is true then reboot attribute is also notified. + * + * Returns: Either count, or an error. + */ +static ssize_t attr_uint_store(struct kobject *kobj, struct kobj_attribute *attr, const char *buf, + size_t count, u32 min, u32 max, u32 *store_value, u32 wmi_dev) +{ + u32 value; + int err; + + err = kstrtouint(buf, 10, &value); + if (err) + return err; + + if (value < min || value > max) + return -EINVAL; + + err = armoury_wmi_set_devstate(attr, value, wmi_dev); + if (err) + return err; + + if (store_value != NULL) + *store_value = value; + sysfs_notify(kobj, NULL, attr->attr.name); + + if (asus_bios_requires_reboot(attr)) + asus_set_reboot_and_signal_event(); + + return count; +} + +static ssize_t enum_type_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + return sysfs_emit(buf, "enumeration\n"); +} + +static ssize_t int_type_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + return sysfs_emit(buf, "integer\n"); +} + +/* Mini-LED mode **************************************************************/ +static ssize_t mini_led_mode_current_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + u32 value; + int err; + + err = asus_wmi_get_devstate_dsts(asus_armoury.mini_led_dev_id, &value); + if (err) + return err; + + value &= ASUS_MINI_LED_MODE_MASK; + + /* + * Remap the mode values to match previous generation mini-LED. The last gen + * WMI 0 == off, while on this version WMI 2 == off (flipped). + */ + if (asus_armoury.mini_led_dev_id == ASUS_WMI_DEVID_MINI_LED_MODE2) { + switch (value) { + case ASUS_MINI_LED_2024_WEAK: + value = ASUS_MINI_LED_ON; + break; + case ASUS_MINI_LED_2024_STRONG: + value = ASUS_MINI_LED_STRONG_MODE; + break; + case ASUS_MINI_LED_2024_OFF: + value = ASUS_MINI_LED_OFF; + break; + } + } + + return sysfs_emit(buf, "%u\n", value); +} + +static ssize_t mini_led_mode_current_value_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + u32 mode; + int err; + + err = kstrtou32(buf, 10, &mode); + if (err) + return err; + + if (asus_armoury.mini_led_dev_id == ASUS_WMI_DEVID_MINI_LED_MODE && + mode > ASUS_MINI_LED_ON) + return -EINVAL; + if (asus_armoury.mini_led_dev_id == ASUS_WMI_DEVID_MINI_LED_MODE2 && + mode > ASUS_MINI_LED_STRONG_MODE) + return -EINVAL; + + /* + * Remap the mode values so expected behaviour is the same as the last + * generation of mini-LED with 0 == off, 1 == on. + */ + if (asus_armoury.mini_led_dev_id == ASUS_WMI_DEVID_MINI_LED_MODE2) { + switch (mode) { + case ASUS_MINI_LED_OFF: + mode = ASUS_MINI_LED_2024_OFF; + break; + case ASUS_MINI_LED_ON: + mode = ASUS_MINI_LED_2024_WEAK; + break; + case ASUS_MINI_LED_STRONG_MODE: + mode = ASUS_MINI_LED_2024_STRONG; + break; + } + } + + err = armoury_wmi_set_devstate(attr, mode, asus_armoury.mini_led_dev_id); + if (err) + return err; + + sysfs_notify(kobj, NULL, attr->attr.name); + + return count; +} + +static ssize_t mini_led_mode_possible_values_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + switch (asus_armoury.mini_led_dev_id) { + case ASUS_WMI_DEVID_MINI_LED_MODE: + return sysfs_emit(buf, "0;1\n"); + case ASUS_WMI_DEVID_MINI_LED_MODE2: + return sysfs_emit(buf, "0;1;2\n"); + } + + return sysfs_emit(buf, "0\n"); +} + +ATTR_GROUP_ENUM_CUSTOM(mini_led_mode, "mini_led_mode", "Set the mini-LED backlight mode"); + +static ssize_t gpu_mux_mode_current_value_store(struct kobject *kobj, + struct kobj_attribute *attr, const char *buf, + size_t count) +{ + int result, err; + u32 optimus; + + err = kstrtou32(buf, 10, &optimus); + if (err) + return err; + + if (optimus > 1) + return -EINVAL; + + if (asus_wmi_is_present(ASUS_WMI_DEVID_DGPU)) { + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_DGPU, &result); + if (err) + return err; + if (result && !optimus) { + pr_warn("Can not switch MUX to dGPU mode when dGPU is disabled: %02X %02X\n", + result, optimus); + return -ENODEV; + } + } + + if (asus_wmi_is_present(ASUS_WMI_DEVID_EGPU)) { + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_EGPU, &result); + if (err) + return err; + if (result && !optimus) { + pr_warn("Can not switch MUX to dGPU mode when eGPU is enabled\n"); + return -ENODEV; + } + } + + err = armoury_wmi_set_devstate(attr, optimus, asus_armoury.gpu_mux_dev_id); + if (err) + return err; + + sysfs_notify(kobj, NULL, attr->attr.name); + asus_set_reboot_and_signal_event(); + + return count; +} +WMI_SHOW_INT(gpu_mux_mode_current_value, "%d\n", asus_armoury.gpu_mux_dev_id); +ATTR_GROUP_BOOL_CUSTOM(gpu_mux_mode, "gpu_mux_mode", "Set the GPU display MUX mode"); + +/* + * A user may be required to store the value twice, typical store first, then + * rescan PCI bus to activate power, then store a second time to save correctly. + */ +static ssize_t dgpu_disable_current_value_store(struct kobject *kobj, + struct kobj_attribute *attr, const char *buf, + size_t count) +{ + int result, err; + u32 disable; + + err = kstrtou32(buf, 10, &disable); + if (err) + return err; + + if (disable > 1) + return -EINVAL; + + if (asus_armoury.gpu_mux_dev_id) { + err = asus_wmi_get_devstate_dsts(asus_armoury.gpu_mux_dev_id, &result); + if (err) + return err; + if (!result && disable) { + pr_warn("Can not disable dGPU when the MUX is in dGPU mode\n"); + return -ENODEV; + } + } + + err = armoury_wmi_set_devstate(attr, disable, ASUS_WMI_DEVID_DGPU); + if (err) + return err; + + sysfs_notify(kobj, NULL, attr->attr.name); + + return count; +} +WMI_SHOW_INT(dgpu_disable_current_value, "%d\n", ASUS_WMI_DEVID_DGPU); +ATTR_GROUP_BOOL_CUSTOM(dgpu_disable, "dgpu_disable", "Disable the dGPU"); + +/* The ACPI call to enable the eGPU also disables the internal dGPU */ +static ssize_t egpu_enable_current_value_store(struct kobject *kobj, struct kobj_attribute *attr, + const char *buf, size_t count) +{ + int result, err; + u32 enable; + + err = kstrtou32(buf, 10, &enable); + if (err) + return err; + + if (enable > 1) + return -EINVAL; + + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_EGPU_CONNECTED, &result); + if (err) { + pr_warn("Failed to get eGPU connection status: %d\n", err); + return err; + } + + if (asus_armoury.gpu_mux_dev_id) { + err = asus_wmi_get_devstate_dsts(asus_armoury.gpu_mux_dev_id, &result); + if (err) { + pr_warn("Failed to get GPU MUX status: %d\n", result); + return result; + } + if (!result && enable) { + pr_warn("Can not enable eGPU when the MUX is in dGPU mode\n"); + return -ENODEV; + } + } + + err = armoury_wmi_set_devstate(attr, enable, ASUS_WMI_DEVID_EGPU); + if (err) + return err; + + sysfs_notify(kobj, NULL, attr->attr.name); + + return count; +} +WMI_SHOW_INT(egpu_enable_current_value, "%d\n", ASUS_WMI_DEVID_EGPU); +ATTR_GROUP_BOOL_CUSTOM(egpu_enable, "egpu_enable", "Enable the eGPU (also disables dGPU)"); + +/* Device memory available to APU */ + +static ssize_t apu_mem_current_value_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + int err; + u32 mem; + + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_APU_MEM, &mem); + if (err) + return err; + + switch (mem) { + case 0x100: + mem = 0; + break; + case 0x102: + mem = 1; + break; + case 0x103: + mem = 2; + break; + case 0x104: + mem = 3; + break; + case 0x105: + mem = 4; + break; + case 0x106: + /* This is out of order and looks wrong but is correct */ + mem = 8; + break; + case 0x107: + mem = 5; + break; + case 0x108: + mem = 6; + break; + case 0x109: + mem = 7; + break; + default: + mem = 4; + break; + } + + return sysfs_emit(buf, "%u\n", mem); +} + +static ssize_t apu_mem_current_value_store(struct kobject *kobj, struct kobj_attribute *attr, + const char *buf, size_t count) +{ + int result, err; + u32 requested, mem; + + result = kstrtou32(buf, 10, &requested); + if (result) + return result; + + switch (requested) { + case 0: + mem = 0x000; + break; + case 1: + mem = 0x102; + break; + case 2: + mem = 0x103; + break; + case 3: + mem = 0x104; + break; + case 4: + mem = 0x105; + break; + case 5: + mem = 0x107; + break; + case 6: + mem = 0x108; + break; + case 7: + mem = 0x109; + break; + case 8: + /* This is out of order and looks wrong but is correct */ + mem = 0x106; + break; + default: + return -EIO; + } + + err = asus_wmi_set_devstate(ASUS_WMI_DEVID_APU_MEM, mem, &result); + if (err) { + pr_warn("Failed to set apu_mem: %d\n", err); + return err; + } + + pr_info("APU memory changed to %uGB, reboot required\n", requested); + sysfs_notify(kobj, NULL, attr->attr.name); + + asus_set_reboot_and_signal_event(); + + return count; +} + +static ssize_t apu_mem_possible_values_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + return sysfs_emit(buf, "0;1;2;3;4;5;6;7;8\n"); +} +ATTR_GROUP_ENUM_CUSTOM(apu_mem, "apu_mem", "Set available system RAM (in GB) for the APU to use"); + +static int init_max_cpu_cores(void) +{ + u32 cores; + int err; + + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_CORES_MAX, &cores); + if (err) + return err; + + cores &= ~ASUS_WMI_DSTS_PRESENCE_BIT; + asus_armoury.cpu_cores->max_power_cores = FIELD_GET(ASUS_POWER_CORE_MASK, cores); + asus_armoury.cpu_cores->max_perf_cores = FIELD_GET(ASUS_PERF_CORE_MASK, cores); + + err = asus_wmi_get_devstate_dsts(ASUS_WMI_DEVID_CORES, &cores); + if (err) { + pr_err("Could not get CPU core count: error %d", err); + return err; + } + + asus_armoury.cpu_cores->cur_perf_cores = FIELD_GET(ASUS_PERF_CORE_MASK, cores); + asus_armoury.cpu_cores->cur_power_cores = FIELD_GET(ASUS_POWER_CORE_MASK, cores); + + asus_armoury.cpu_cores->min_perf_cores = CPU_PERF_CORE_COUNT_MIN; + asus_armoury.cpu_cores->min_power_cores = CPU_POWR_CORE_COUNT_MIN; + + return 0; +} + +static ssize_t cores_value_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf, + enum cpu_core_type core_type, enum cpu_core_value core_value) +{ + u32 cores; + + switch (core_value) { + case CPU_CORE_DEFAULT: + case CPU_CORE_MAX: + if (core_type == CPU_CORE_PERF) + return sysfs_emit(buf, "%d\n", + asus_armoury.cpu_cores->max_perf_cores); + else + return sysfs_emit(buf, "%d\n", + asus_armoury.cpu_cores->max_power_cores); + case CPU_CORE_MIN: + if (core_type == CPU_CORE_PERF) + return sysfs_emit(buf, "%d\n", + asus_armoury.cpu_cores->min_perf_cores); + else + return sysfs_emit(buf, "%d\n", + asus_armoury.cpu_cores->min_power_cores); + default: + break; + } + + if (core_type == CPU_CORE_PERF) + cores = asus_armoury.cpu_cores->cur_perf_cores; + else + cores = asus_armoury.cpu_cores->cur_power_cores; + + return sysfs_emit(buf, "%d\n", cores); +} + +static ssize_t cores_current_value_store(struct kobject *kobj, struct kobj_attribute *attr, + const char *buf, enum cpu_core_type core_type) +{ + u32 new_cores, perf_cores, power_cores, out_val, min, max; + int result, err; + + result = kstrtou32(buf, 10, &new_cores); + if (result) + return result; + + mutex_lock(&asus_armoury.cpu_core_mutex); + + if (core_type == CPU_CORE_PERF) { + perf_cores = new_cores; + power_cores = out_val = asus_armoury.cpu_cores->cur_power_cores; + min = asus_armoury.cpu_cores->min_perf_cores; + max = asus_armoury.cpu_cores->max_perf_cores; + } else { + perf_cores = asus_armoury.cpu_cores->cur_perf_cores; + power_cores = out_val = new_cores; + min = asus_armoury.cpu_cores->min_power_cores; + max = asus_armoury.cpu_cores->max_power_cores; + } + + if (new_cores < min || new_cores > max) { + mutex_unlock(&asus_armoury.cpu_core_mutex); + return -EINVAL; + } + + out_val = 0; + out_val |= FIELD_PREP(ASUS_PERF_CORE_MASK, perf_cores); + out_val |= FIELD_PREP(ASUS_POWER_CORE_MASK, power_cores); + + err = asus_wmi_set_devstate(ASUS_WMI_DEVID_CORES, out_val, &result); + + if (err) { + pr_warn("Failed to set CPU core count: %d\n", err); + mutex_unlock(&asus_armoury.cpu_core_mutex); + return err; + } + + if (result > 1) { + pr_warn("Failed to set CPU core count (result): 0x%x\n", result); + mutex_unlock(&asus_armoury.cpu_core_mutex); + return -EIO; + } + + pr_info("CPU core count changed, reboot required\n"); + mutex_unlock(&asus_armoury.cpu_core_mutex); + + sysfs_notify(kobj, NULL, attr->attr.name); + asus_set_reboot_and_signal_event(); + + return 0; +} + +static ssize_t cores_performance_min_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_PERF, CPU_CORE_MIN); +} + +static ssize_t cores_performance_max_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_PERF, CPU_CORE_MAX); +} + +static ssize_t cores_performance_default_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_PERF, CPU_CORE_DEFAULT); +} + +static ssize_t cores_performance_current_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_PERF, CPU_CORE_CURRENT); +} + +static ssize_t cores_performance_current_value_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + int err; + + err = cores_current_value_store(kobj, attr, buf, CPU_CORE_PERF); + if (err) + return err; + + return count; +} +ATTR_GROUP_CORES_RW(cores_performance, "cores_performance", + "Set the max available performance cores"); + +static ssize_t cores_efficiency_min_value_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_POWER, CPU_CORE_MIN); +} + +static ssize_t cores_efficiency_max_value_show(struct kobject *kobj, struct kobj_attribute *attr, + char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_POWER, CPU_CORE_MAX); +} + +static ssize_t cores_efficiency_default_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_POWER, CPU_CORE_DEFAULT); +} + +static ssize_t cores_efficiency_current_value_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return cores_value_show(kobj, attr, buf, CPU_CORE_POWER, CPU_CORE_CURRENT); +} + +static ssize_t cores_efficiency_current_value_store(struct kobject *kobj, + struct kobj_attribute *attr, const char *buf, + size_t count) +{ + int err; + + err = cores_current_value_store(kobj, attr, buf, CPU_CORE_POWER); + if (err) + return err; + + return count; +} +ATTR_GROUP_CORES_RW(cores_efficiency, "cores_efficiency", + "Set the max available efficiency cores"); + +/* Define helper to access the current power mode tunable values */ +static inline struct rog_tunables *get_current_tunables(void) +{ + return asus_armoury + .rog_tunables[power_supply_is_system_supplied() ? 1 : 0]; +} + +/* Simple attribute creation */ +ATTR_GROUP_ROG_TUNABLE(ppt_pl1_spl, ATTR_PPT_PL1_SPL, ASUS_WMI_DEVID_PPT_PL1_SPL, + "Set the CPU slow package limit"); +ATTR_GROUP_ROG_TUNABLE(ppt_pl2_sppt, ATTR_PPT_PL2_SPPT, ASUS_WMI_DEVID_PPT_PL2_SPPT, + "Set the CPU fast package limit"); +ATTR_GROUP_ROG_TUNABLE(ppt_pl3_fppt, ATTR_PPT_PL3_FPPT, ASUS_WMI_DEVID_PPT_FPPT, + "Set the CPU fastest package limit"); +ATTR_GROUP_ROG_TUNABLE(ppt_apu_sppt, ATTR_PPT_APU_SPPT, ASUS_WMI_DEVID_PPT_APU_SPPT, + "Set the APU package limit"); +ATTR_GROUP_ROG_TUNABLE(ppt_platform_sppt, ATTR_PPT_PLATFORM_SPPT, ASUS_WMI_DEVID_PPT_PLAT_SPPT, + "Set the platform package limit"); +ATTR_GROUP_ROG_TUNABLE(nv_dynamic_boost, ATTR_NV_DYNAMIC_BOOST, ASUS_WMI_DEVID_NV_DYN_BOOST, + "Set the Nvidia dynamic boost limit"); +ATTR_GROUP_ROG_TUNABLE(nv_temp_target, ATTR_NV_TEMP_TARGET, ASUS_WMI_DEVID_NV_THERM_TARGET, + "Set the Nvidia max thermal limit"); +ATTR_GROUP_ROG_TUNABLE(nv_tgp, "nv_tgp", ASUS_WMI_DEVID_DGPU_SET_TGP, + "Set the additional TGP on top of the base TGP"); +ATTR_GROUP_INT_VALUE_ONLY_RO(nv_base_tgp, ATTR_NV_BASE_TGP, ASUS_WMI_DEVID_DGPU_BASE_TGP, + "Read the base TGP value"); + + +ATTR_GROUP_ENUM_INT_RO(charge_mode, "charge_mode", ASUS_WMI_DEVID_CHARGE_MODE, "0;1;2", + "Show the current mode of charging"); + +ATTR_GROUP_BOOL_RW(boot_sound, "boot_sound", ASUS_WMI_DEVID_BOOT_SOUND, + "Set the boot POST sound"); +ATTR_GROUP_BOOL_RW(mcu_powersave, "mcu_powersave", ASUS_WMI_DEVID_MCU_POWERSAVE, + "Set MCU powersaving mode"); +ATTR_GROUP_BOOL_RW(panel_od, "panel_overdrive", ASUS_WMI_DEVID_PANEL_OD, + "Set the panel refresh overdrive"); +ATTR_GROUP_BOOL_RW(panel_hd_mode, "panel_hd_mode", ASUS_WMI_DEVID_PANEL_HD, + "Set the panel HD mode to UHD<0> or FHD<1>"); +ATTR_GROUP_BOOL_RW(screen_auto_brightness, "screen_auto_brightness", + ASUS_WMI_DEVID_SCREEN_AUTO_BRIGHTNESS, + "Set the panel brightness to Off<0> or On<1>"); +ATTR_GROUP_BOOL_RO(egpu_connected, "egpu_connected", ASUS_WMI_DEVID_EGPU_CONNECTED, + "Show the eGPU connection status"); + +/* If an attribute does not require any special case handling add it here */ +static const struct asus_attr_group armoury_attr_groups[] = { + { &egpu_connected_attr_group, ASUS_WMI_DEVID_EGPU_CONNECTED }, + { &egpu_enable_attr_group, ASUS_WMI_DEVID_EGPU }, + { &dgpu_disable_attr_group, ASUS_WMI_DEVID_DGPU }, + { &apu_mem_attr_group, ASUS_WMI_DEVID_APU_MEM }, + { &cores_efficiency_attr_group, ASUS_WMI_DEVID_CORES_MAX }, + { &cores_performance_attr_group, ASUS_WMI_DEVID_CORES_MAX }, + + { &ppt_pl1_spl_attr_group, ASUS_WMI_DEVID_PPT_PL1_SPL }, + { &ppt_pl2_sppt_attr_group, ASUS_WMI_DEVID_PPT_PL2_SPPT }, + { &ppt_pl3_fppt_attr_group, ASUS_WMI_DEVID_PPT_FPPT }, + { &ppt_apu_sppt_attr_group, ASUS_WMI_DEVID_PPT_APU_SPPT }, + { &ppt_platform_sppt_attr_group, ASUS_WMI_DEVID_PPT_PLAT_SPPT }, + { &nv_dynamic_boost_attr_group, ASUS_WMI_DEVID_NV_DYN_BOOST }, + { &nv_temp_target_attr_group, ASUS_WMI_DEVID_NV_THERM_TARGET }, + { &nv_base_tgp_attr_group, ASUS_WMI_DEVID_DGPU_BASE_TGP }, + { &nv_tgp_attr_group, ASUS_WMI_DEVID_DGPU_SET_TGP }, + + { &charge_mode_attr_group, ASUS_WMI_DEVID_CHARGE_MODE }, + { &boot_sound_attr_group, ASUS_WMI_DEVID_BOOT_SOUND }, + { &mcu_powersave_attr_group, ASUS_WMI_DEVID_MCU_POWERSAVE }, + { &panel_od_attr_group, ASUS_WMI_DEVID_PANEL_OD }, + { &panel_hd_mode_attr_group, ASUS_WMI_DEVID_PANEL_HD }, +}; + +/** + * is_power_tunable_attr - Determines if an attribute is a power-related tunable + * @name: The name of the attribute to check + * + * This function checks if the given attribute name is related to power tuning. + * + * Return: true if the attribute is a power-related tunable, false otherwise + */ +static bool is_power_tunable_attr(const char *name) +{ + static const char * const power_tunable_attrs[] = { + ATTR_PPT_PL1_SPL, ATTR_PPT_PL2_SPPT, + ATTR_PPT_PL3_FPPT, ATTR_PPT_APU_SPPT, + ATTR_PPT_PLATFORM_SPPT, ATTR_NV_DYNAMIC_BOOST, + ATTR_NV_TEMP_TARGET, ATTR_NV_BASE_TGP, + ATTR_NV_TGP + }; + + for (int i = 0; i < ARRAY_SIZE(power_tunable_attrs); i++) { + if (!strcmp(name, power_tunable_attrs[i])) + return true; + } + + return false; +} + +/** + * has_valid_limit - Checks if a power-related attribute has a valid limit value + * @name: The name of the attribute to check + * @limits: Pointer to the power_limits structure containing limit values + * + * This function checks if a power-related attribute has a valid limit value. + * It returns false if limits is NULL or if the corresponding limit value is zero. + * + * Return: true if the attribute has a valid limit value, false otherwise + */ +static bool has_valid_limit(const char *name, const struct power_limits *limits) +{ + u32 limit_value = 0; + + if (!limits) + return false; + + if (!strcmp(name, ATTR_PPT_PL1_SPL)) + limit_value = limits->ppt_pl1_spl_max; + else if (!strcmp(name, ATTR_PPT_PL2_SPPT)) + limit_value = limits->ppt_pl2_sppt_max; + else if (!strcmp(name, ATTR_PPT_PL3_FPPT)) + limit_value = limits->ppt_pl3_fppt_max; + else if (!strcmp(name, ATTR_PPT_APU_SPPT)) + limit_value = limits->ppt_apu_sppt_max; + else if (!strcmp(name, ATTR_PPT_PLATFORM_SPPT)) + limit_value = limits->ppt_platform_sppt_max; + else if (!strcmp(name, ATTR_NV_DYNAMIC_BOOST)) + limit_value = limits->nv_dynamic_boost_max; + else if (!strcmp(name, ATTR_NV_TEMP_TARGET)) + limit_value = limits->nv_temp_target_max; + else if (!strcmp(name, ATTR_NV_BASE_TGP) || + !strcmp(name, ATTR_NV_TGP)) + limit_value = limits->nv_tgp_max; + + return limit_value > 0; +} + +static int asus_fw_attr_add(void) +{ + const struct power_limits *limits; + bool should_create; + const char *name; + int err, i; + + asus_armoury.fw_attr_dev = device_create(&firmware_attributes_class, NULL, MKDEV(0, 0), + NULL, "%s", DRIVER_NAME); + if (IS_ERR(asus_armoury.fw_attr_dev)) { + err = PTR_ERR(asus_armoury.fw_attr_dev); + goto fail_class_get; + } + + asus_armoury.fw_attr_kset = kset_create_and_add("attributes", NULL, + &asus_armoury.fw_attr_dev->kobj); + if (!asus_armoury.fw_attr_kset) { + err = -ENOMEM; + goto err_destroy_classdev; + } + + err = sysfs_create_file(&asus_armoury.fw_attr_kset->kobj, &pending_reboot.attr); + if (err) { + pr_err("Failed to create sysfs level attributes\n"); + goto err_destroy_kset; + } + + asus_armoury.mini_led_dev_id = 0; + if (asus_wmi_is_present(ASUS_WMI_DEVID_MINI_LED_MODE)) + asus_armoury.mini_led_dev_id = ASUS_WMI_DEVID_MINI_LED_MODE; + else if (asus_wmi_is_present(ASUS_WMI_DEVID_MINI_LED_MODE2)) + asus_armoury.mini_led_dev_id = ASUS_WMI_DEVID_MINI_LED_MODE2; + + if (asus_armoury.mini_led_dev_id) { + err = sysfs_create_group(&asus_armoury.fw_attr_kset->kobj, + &mini_led_mode_attr_group); + if (err) { + pr_err("Failed to create sysfs-group for mini_led\n"); + goto err_remove_file; + } + } + + asus_armoury.gpu_mux_dev_id = 0; + if (asus_wmi_is_present(ASUS_WMI_DEVID_GPU_MUX)) + asus_armoury.gpu_mux_dev_id = ASUS_WMI_DEVID_GPU_MUX; + else if (asus_wmi_is_present(ASUS_WMI_DEVID_GPU_MUX_VIVO)) + asus_armoury.gpu_mux_dev_id = ASUS_WMI_DEVID_GPU_MUX_VIVO; + + if (asus_armoury.gpu_mux_dev_id) { + err = sysfs_create_group(&asus_armoury.fw_attr_kset->kobj, + &gpu_mux_mode_attr_group); + if (err) { + pr_err("Failed to create sysfs-group for gpu_mux\n"); + goto err_remove_mini_led_group; + } + } + + for (i = 0; i < ARRAY_SIZE(armoury_attr_groups); i++) { + if (!asus_wmi_is_present(armoury_attr_groups[i].wmi_devid)) + continue; + + /* Always create by default, unless PPT is not present */ + should_create = true; + name = armoury_attr_groups[i].attr_group->name; + + /* Check if this is a power-related tunable requiring limits */ + if (asus_armoury.rog_tunables[1] && asus_armoury.rog_tunables[1]->power_limits && + is_power_tunable_attr(name)) { + limits = asus_armoury.rog_tunables[1]->power_limits; + /* Check only AC, if DC is not present then AC won't be either */ + should_create = has_valid_limit(name, limits); + if (!should_create) { + pr_debug( + "Missing max value on %s for tunable: %s\n", + dmi_get_system_info(DMI_BOARD_NAME), + name); + } + } + + if (should_create) { + err = sysfs_create_group( + &asus_armoury.fw_attr_kset->kobj, + armoury_attr_groups[i].attr_group); + if (err) { + pr_err("Failed to create sysfs-group for %s\n", + armoury_attr_groups[i].attr_group->name); + goto err_remove_groups; + } + } + } + + return 0; + +err_remove_groups: + while (--i >= 0) { + if (asus_wmi_is_present(armoury_attr_groups[i].wmi_devid)) + sysfs_remove_group(&asus_armoury.fw_attr_kset->kobj, + armoury_attr_groups[i].attr_group); + } + if (asus_armoury.gpu_mux_dev_id) + sysfs_remove_group(&asus_armoury.fw_attr_kset->kobj, &gpu_mux_mode_attr_group); +err_remove_mini_led_group: + if (asus_armoury.mini_led_dev_id) + sysfs_remove_group(&asus_armoury.fw_attr_kset->kobj, &mini_led_mode_attr_group); +err_remove_file: + sysfs_remove_file(&asus_armoury.fw_attr_kset->kobj, &pending_reboot.attr); +err_destroy_kset: + kset_unregister(asus_armoury.fw_attr_kset); +err_destroy_classdev: +fail_class_get: + device_destroy(&firmware_attributes_class, MKDEV(0, 0)); + return err; +} + +/* Init / exit ****************************************************************/ + +/* Set up the min/max and defaults for ROG tunables */ +static void init_rog_tunables(void) +{ + const struct power_limits *ac_limits, *dc_limits; + const struct power_data *power_data; + const struct dmi_system_id *dmi_id; + bool ac_initialized = false, dc_initialized = false; + + /* Match the system against the power_limits table */ + dmi_id = dmi_first_match(power_limits); + if (!dmi_id) { + pr_warn("No matching power limits found for this system\n"); + return; + } + + /* Get the power data for this system */ + power_data = dmi_id->driver_data; + if (!power_data) { + pr_info("No power data available for this system\n"); + return; + } + + /* Initialize AC power tunables */ + ac_limits = power_data->ac_data; + if (ac_limits) { + asus_armoury.rog_tunables[1] = + kzalloc(sizeof(struct rog_tunables), GFP_KERNEL); + if (!asus_armoury.rog_tunables[1]) + goto err_nomem; + + asus_armoury.rog_tunables[1]->power_limits = ac_limits; + + /* Set initial AC values */ + asus_armoury.rog_tunables[1]->ppt_pl1_spl = + ac_limits->ppt_pl1_spl_def ? + ac_limits->ppt_pl1_spl_def : + ac_limits->ppt_pl1_spl_max; + + asus_armoury.rog_tunables[1]->ppt_pl2_sppt = + ac_limits->ppt_pl2_sppt_def ? + ac_limits->ppt_pl2_sppt_def : + ac_limits->ppt_pl2_sppt_max; + + asus_armoury.rog_tunables[1]->ppt_pl3_fppt = + ac_limits->ppt_pl3_fppt_def ? + ac_limits->ppt_pl3_fppt_def : + ac_limits->ppt_pl3_fppt_max; + + asus_armoury.rog_tunables[1]->ppt_apu_sppt = + ac_limits->ppt_apu_sppt_def ? + ac_limits->ppt_apu_sppt_def : + ac_limits->ppt_apu_sppt_max; + + asus_armoury.rog_tunables[1]->ppt_platform_sppt = + ac_limits->ppt_platform_sppt_def ? + ac_limits->ppt_platform_sppt_def : + ac_limits->ppt_platform_sppt_max; + + asus_armoury.rog_tunables[1]->nv_dynamic_boost = + ac_limits->nv_dynamic_boost_max; + asus_armoury.rog_tunables[1]->nv_temp_target = + ac_limits->nv_temp_target_max; + asus_armoury.rog_tunables[1]->nv_tgp = ac_limits->nv_tgp_max; + + ac_initialized = true; + pr_debug("AC power limits initialized for %s\n", dmi_id->matches[0].substr); + } + + /* Initialize DC power tunables */ + dc_limits = power_data->dc_data; + if (dc_limits) { + asus_armoury.rog_tunables[0] = + kzalloc(sizeof(struct rog_tunables), GFP_KERNEL); + if (!asus_armoury.rog_tunables[0]) { + if (ac_initialized) + kfree(asus_armoury.rog_tunables[1]); + goto err_nomem; + } + + asus_armoury.rog_tunables[0]->power_limits = dc_limits; + + /* Set initial DC values */ + asus_armoury.rog_tunables[0]->ppt_pl1_spl = + dc_limits->ppt_pl1_spl_def ? + dc_limits->ppt_pl1_spl_def : + dc_limits->ppt_pl1_spl_max; + + asus_armoury.rog_tunables[0]->ppt_pl2_sppt = + dc_limits->ppt_pl2_sppt_def ? + dc_limits->ppt_pl2_sppt_def : + dc_limits->ppt_pl2_sppt_max; + + asus_armoury.rog_tunables[0]->ppt_pl3_fppt = + dc_limits->ppt_pl3_fppt_def ? + dc_limits->ppt_pl3_fppt_def : + dc_limits->ppt_pl3_fppt_max; + + asus_armoury.rog_tunables[0]->ppt_apu_sppt = + dc_limits->ppt_apu_sppt_def ? + dc_limits->ppt_apu_sppt_def : + dc_limits->ppt_apu_sppt_max; + + asus_armoury.rog_tunables[0]->ppt_platform_sppt = + dc_limits->ppt_platform_sppt_def ? + dc_limits->ppt_platform_sppt_def : + dc_limits->ppt_platform_sppt_max; + + asus_armoury.rog_tunables[0]->nv_dynamic_boost = + dc_limits->nv_dynamic_boost_max; + asus_armoury.rog_tunables[0]->nv_temp_target = + dc_limits->nv_temp_target_max; + asus_armoury.rog_tunables[0]->nv_tgp = dc_limits->nv_tgp_max; + + dc_initialized = true; + pr_debug("DC power limits initialized for %s\n", dmi_id->matches[0].substr); + } + + if (!ac_initialized) + pr_debug("No AC PPT limits defined\n"); + + if (!dc_initialized) + pr_debug("No DC PPT limits defined\n"); + + return; + +err_nomem: + pr_err("Failed to allocate memory for tunables\n"); +} + +static int __init asus_fw_init(void) +{ + char *wmi_uid; + int err; + + wmi_uid = wmi_get_acpi_device_uid(ASUS_WMI_MGMT_GUID); + if (!wmi_uid) + return -ENODEV; + + /* + * if equal to "ASUSWMI" then it's DCTS that can't be used for this + * driver, DSTS is required. + */ + if (!strcmp(wmi_uid, ASUS_ACPI_UID_ASUSWMI)) + return -ENODEV; + + if (asus_wmi_is_present(ASUS_WMI_DEVID_CORES_MAX)) { + asus_armoury.cpu_cores = kzalloc(sizeof(struct cpu_cores), GFP_KERNEL); + if (!asus_armoury.cpu_cores) + return -ENOMEM; + + err = init_max_cpu_cores(); + if (err) { + kfree(asus_armoury.cpu_cores); + pr_err("Could not initialise CPU core control %d\n", err); + return err; + } + } + + init_rog_tunables(); + + /* Must always be last step to ensure data is available */ + return asus_fw_attr_add(); +} + +static void __exit asus_fw_exit(void) +{ + sysfs_remove_file(&asus_armoury.fw_attr_kset->kobj, &pending_reboot.attr); + kset_unregister(asus_armoury.fw_attr_kset); + device_destroy(&firmware_attributes_class, MKDEV(0, 0)); + + kfree(asus_armoury.rog_tunables[0]); + kfree(asus_armoury.rog_tunables[1]); +} + +module_init(asus_fw_init); +module_exit(asus_fw_exit); + +MODULE_IMPORT_NS("ASUS_WMI"); +MODULE_AUTHOR("Luke Jones "); +MODULE_DESCRIPTION("ASUS BIOS Configuration Driver"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("wmi:" ASUS_NB_WMI_EVENT_GUID); diff --git a/drivers/platform/x86/asus-armoury.h b/drivers/platform/x86/asus-armoury.h new file mode 100644 index 000000000000..438768ea14cc --- /dev/null +++ b/drivers/platform/x86/asus-armoury.h @@ -0,0 +1,1278 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Definitions for kernel modules using asus-armoury driver + * + * Copyright (c) 2024 Luke Jones + */ + +#ifndef _ASUS_ARMOURY_H_ +#define _ASUS_ARMOURY_H_ + +#include +#include +#include + +#define DRIVER_NAME "asus-armoury" + +#define __ASUS_ATTR_RO(_func, _name) \ + { \ + .attr = { .name = __stringify(_name), .mode = 0444 }, \ + .show = _func##_##_name##_show, \ + } + +#define __ASUS_ATTR_RO_AS(_name, _show) \ + { \ + .attr = { .name = __stringify(_name), .mode = 0444 }, \ + .show = _show, \ + } + +#define __ASUS_ATTR_RW(_func, _name) \ + __ATTR(_name, 0644, _func##_##_name##_show, _func##_##_name##_store) + +#define __WMI_STORE_INT(_attr, _min, _max, _wmi) \ + static ssize_t _attr##_store(struct kobject *kobj, \ + struct kobj_attribute *attr, \ + const char *buf, size_t count) \ + { \ + return attr_uint_store(kobj, attr, buf, count, _min, \ + _max, NULL, _wmi); \ + } + +#define WMI_SHOW_INT(_attr, _fmt, _wmi) \ + static ssize_t _attr##_show(struct kobject *kobj, \ + struct kobj_attribute *attr, char *buf) \ + { \ + u32 result; \ + int err; \ + \ + err = asus_wmi_get_devstate_dsts(_wmi, &result); \ + if (err) \ + return err; \ + return sysfs_emit(buf, _fmt, \ + result & ~ASUS_WMI_DSTS_PRESENCE_BIT); \ + } + +/* Create functions and attributes for use in other macros or on their own */ + +/* Shows a formatted static variable */ +#define __ATTR_SHOW_FMT(_prop, _attrname, _fmt, _val) \ + static ssize_t _attrname##_##_prop##_show( \ + struct kobject *kobj, struct kobj_attribute *attr, char *buf) \ + { \ + return sysfs_emit(buf, _fmt, _val); \ + } \ + static struct kobj_attribute attr_##_attrname##_##_prop = \ + __ASUS_ATTR_RO(_attrname, _prop) + +#define __ATTR_RO_INT_GROUP_ENUM(_attrname, _wmi, _fsname, _possible, _dispname)\ + WMI_SHOW_INT(_attrname##_current_value, "%d\n", _wmi); \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RO(_attrname, current_value); \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + __ATTR_SHOW_FMT(possible_values, _attrname, "%s\n", _possible); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, enum_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_possible_values.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +#define __ATTR_RW_INT_GROUP_ENUM(_attrname, _minv, _maxv, _wmi, _fsname,\ + _possible, _dispname) \ + __WMI_STORE_INT(_attrname##_current_value, _minv, _maxv, _wmi); \ + WMI_SHOW_INT(_attrname##_current_value, "%d\n", _wmi); \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RW(_attrname, current_value); \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + __ATTR_SHOW_FMT(possible_values, _attrname, "%s\n", _possible); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, enum_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_possible_values.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +/* Boolean style enumeration, base macro. Requires adding show/store */ +#define __ATTR_GROUP_ENUM(_attrname, _fsname, _possible, _dispname) \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + __ATTR_SHOW_FMT(possible_values, _attrname, "%s\n", _possible); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, enum_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_possible_values.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +#define ATTR_GROUP_BOOL_RO(_attrname, _fsname, _wmi, _dispname) \ + __ATTR_RO_INT_GROUP_ENUM(_attrname, _wmi, _fsname, "0;1", _dispname) + + +#define ATTR_GROUP_BOOL_RW(_attrname, _fsname, _wmi, _dispname) \ + __ATTR_RW_INT_GROUP_ENUM(_attrname, 0, 1, _wmi, _fsname, "0;1", _dispname) + +#define ATTR_GROUP_ENUM_INT_RO(_attrname, _fsname, _wmi, _possible, _dispname) \ + __ATTR_RO_INT_GROUP_ENUM(_attrname, _wmi, _fsname, _possible, _dispname) + +/* + * Requires _current_value_show(), _current_value_show() + */ +#define ATTR_GROUP_BOOL_CUSTOM(_attrname, _fsname, _dispname) \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RW(_attrname, current_value); \ + __ATTR_GROUP_ENUM(_attrname, _fsname, "0;1", _dispname) + +/* + * Requires _current_value_show(), _current_value_show() + * and _possible_values_show() + */ +#define ATTR_GROUP_ENUM_CUSTOM(_attrname, _fsname, _dispname) \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RW(_attrname, current_value); \ + static struct kobj_attribute attr_##_attrname##_possible_values = \ + __ASUS_ATTR_RO(_attrname, possible_values); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, enum_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_possible_values.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +/* CPU core attributes need a little different in setup */ +#define ATTR_GROUP_CORES_RW(_attrname, _fsname, _dispname) \ + __ATTR_SHOW_FMT(scalar_increment, _attrname, "%d\n", 1); \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RW(_attrname, current_value); \ + static struct kobj_attribute attr_##_attrname##_default_value = \ + __ASUS_ATTR_RO(_attrname, default_value); \ + static struct kobj_attribute attr_##_attrname##_min_value = \ + __ASUS_ATTR_RO(_attrname, min_value); \ + static struct kobj_attribute attr_##_attrname##_max_value = \ + __ASUS_ATTR_RO(_attrname, max_value); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, int_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_default_value.attr, \ + &attr_##_attrname##_min_value.attr, \ + &attr_##_attrname##_max_value.attr, \ + &attr_##_attrname##_scalar_increment.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +#define ATTR_GROUP_INT_VALUE_ONLY_RO(_attrname, _fsname, _wmi, _dispname) \ + WMI_SHOW_INT(_attrname##_current_value, "%d\n", _wmi); \ + static struct kobj_attribute attr_##_attrname##_current_value = \ + __ASUS_ATTR_RO(_attrname, current_value); \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, int_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_type.attr, NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +/* + * ROG PPT attributes need a little different in setup as they + * require rog_tunables members. + */ + +#define __ROG_TUNABLE_SHOW(_prop, _attrname, _val) \ + static ssize_t _attrname##_##_prop##_show( \ + struct kobject *kobj, struct kobj_attribute *attr, char *buf) \ + { \ + struct rog_tunables *tunables = get_current_tunables(); \ + \ + if (!tunables || !tunables->power_limits) \ + return -ENODEV; \ + \ + return sysfs_emit(buf, "%d\n", tunables->power_limits->_val); \ + } \ + static struct kobj_attribute attr_##_attrname##_##_prop = \ + __ASUS_ATTR_RO(_attrname, _prop) + +#define __ROG_TUNABLE_SHOW_DEFAULT(_attrname) \ + static ssize_t _attrname##_default_value_show( \ + struct kobject *kobj, struct kobj_attribute *attr, char *buf) \ + { \ + struct rog_tunables *tunables = get_current_tunables(); \ + \ + if (!tunables || !tunables->power_limits) \ + return -ENODEV; \ + \ + return sysfs_emit( \ + buf, "%d\n", \ + tunables->power_limits->_attrname##_def ? \ + tunables->power_limits->_attrname##_def : \ + tunables->power_limits->_attrname##_max); \ + } \ + static struct kobj_attribute attr_##_attrname##_default_value = \ + __ASUS_ATTR_RO(_attrname, default_value) + +#define __ROG_TUNABLE_RW(_attr, _wmi) \ + static ssize_t _attr##_current_value_store( \ + struct kobject *kobj, struct kobj_attribute *attr, \ + const char *buf, size_t count) \ + { \ + struct rog_tunables *tunables = get_current_tunables(); \ + \ + if (!tunables || !tunables->power_limits) \ + return -ENODEV; \ + \ + return attr_uint_store(kobj, attr, buf, count, \ + tunables->power_limits->_attr##_min, \ + tunables->power_limits->_attr##_max, \ + &tunables->_attr, _wmi); \ + } \ + static ssize_t _attr##_current_value_show( \ + struct kobject *kobj, struct kobj_attribute *attr, char *buf) \ + { \ + struct rog_tunables *tunables = get_current_tunables(); \ + \ + if (!tunables) \ + return -ENODEV; \ + \ + return sysfs_emit(buf, "%u\n", tunables->_attr); \ + } \ + static struct kobj_attribute attr_##_attr##_current_value = \ + __ASUS_ATTR_RW(_attr, current_value) + +#define ATTR_GROUP_ROG_TUNABLE(_attrname, _fsname, _wmi, _dispname) \ + __ROG_TUNABLE_RW(_attrname, _wmi); \ + __ROG_TUNABLE_SHOW_DEFAULT(_attrname); \ + __ROG_TUNABLE_SHOW(min_value, _attrname, _attrname##_min); \ + __ROG_TUNABLE_SHOW(max_value, _attrname, _attrname##_max); \ + __ATTR_SHOW_FMT(scalar_increment, _attrname, "%d\n", 1); \ + __ATTR_SHOW_FMT(display_name, _attrname, "%s\n", _dispname); \ + static struct kobj_attribute attr_##_attrname##_type = \ + __ASUS_ATTR_RO_AS(type, int_type_show); \ + static struct attribute *_attrname##_attrs[] = { \ + &attr_##_attrname##_current_value.attr, \ + &attr_##_attrname##_default_value.attr, \ + &attr_##_attrname##_min_value.attr, \ + &attr_##_attrname##_max_value.attr, \ + &attr_##_attrname##_scalar_increment.attr, \ + &attr_##_attrname##_display_name.attr, \ + &attr_##_attrname##_type.attr, \ + NULL \ + }; \ + static const struct attribute_group _attrname##_attr_group = { \ + .name = _fsname, .attrs = _attrname##_attrs \ + } + +/* Default is always the maximum value unless *_def is specified */ +struct power_limits { + u8 ppt_pl1_spl_min; + u8 ppt_pl1_spl_def; + u8 ppt_pl1_spl_max; + u8 ppt_pl2_sppt_min; + u8 ppt_pl2_sppt_def; + u8 ppt_pl2_sppt_max; + u8 ppt_pl3_fppt_min; + u8 ppt_pl3_fppt_def; + u8 ppt_pl3_fppt_max; + u8 ppt_apu_sppt_min; + u8 ppt_apu_sppt_def; + u8 ppt_apu_sppt_max; + u8 ppt_platform_sppt_min; + u8 ppt_platform_sppt_def; + u8 ppt_platform_sppt_max; + /* Nvidia GPU specific, default is always max */ + u8 nv_dynamic_boost_def; // unused. exists for macro + u8 nv_dynamic_boost_min; + u8 nv_dynamic_boost_max; + u8 nv_temp_target_def; // unused. exists for macro + u8 nv_temp_target_min; + u8 nv_temp_target_max; + u8 nv_tgp_def; // unused. exists for macro + u8 nv_tgp_min; + u8 nv_tgp_max; +}; + +struct power_data { + const struct power_limits *ac_data; + const struct power_limits *dc_data; + bool requires_fan_curve; +}; + +/* + * For each avilable attribute there must be a min and a max. + * _def is not required and will be assumed to be default == max if missing. + */ +static const struct dmi_system_id power_limits[] = { + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA401W"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 75, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 30, + .ppt_pl2_sppt_min = 31, + .ppt_pl2_sppt_max = 44, + .ppt_pl3_fppt_min = 45, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA507N"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 45, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 54, + .ppt_pl2_sppt_max = 65, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA507R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80 + }, + .dc_data = NULL + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA507X"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 85, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 45, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 54, + .ppt_pl2_sppt_max = 65, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA507Z"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 105, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 15, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 85, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 45, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 60, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA607P"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 30, + .ppt_pl1_spl_def = 100, + .ppt_pl1_spl_max = 135, + .ppt_pl2_sppt_min = 30, + .ppt_pl2_sppt_def = 115, + .ppt_pl2_sppt_max = 135, + .ppt_pl3_fppt_min = 30, + .ppt_pl3_fppt_max = 135, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 115, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_def = 45, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_def = 60, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 25, + .ppt_pl3_fppt_max = 80, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA617NS"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_apu_sppt_min = 15, + .ppt_apu_sppt_max = 80, + .ppt_platform_sppt_min = 30, + .ppt_platform_sppt_max = 120 + }, + .dc_data = &(struct power_limits) { + .ppt_apu_sppt_min = 25, + .ppt_apu_sppt_max = 35, + .ppt_platform_sppt_min = 45, + .ppt_platform_sppt_max = 100 + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA617NT"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_apu_sppt_min = 15, + .ppt_apu_sppt_max = 80, + .ppt_platform_sppt_min = 30, + .ppt_platform_sppt_max = 115 + }, + .dc_data = &(struct power_limits) { + .ppt_apu_sppt_min = 15, + .ppt_apu_sppt_max = 45, + .ppt_platform_sppt_min = 30, + .ppt_platform_sppt_max = 50 + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FA617XS"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_apu_sppt_min = 15, + .ppt_apu_sppt_max = 80, + .ppt_platform_sppt_min = 30, + .ppt_platform_sppt_max = 120, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_apu_sppt_min = 25, + .ppt_apu_sppt_max = 35, + .ppt_platform_sppt_min = 45, + .ppt_platform_sppt_max = 100, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "FX507Z"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 90, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 135, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 15, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 45, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 60, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GA401Q"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_max = 80, + }, + .dc_data = NULL + }, + }, + { + .matches = { + // This model is full AMD. No Nvidia dGPU. + DMI_MATCH(DMI_BOARD_NAME, "GA402R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_apu_sppt_min = 15, + .ppt_apu_sppt_max = 80, + .ppt_platform_sppt_min = 30, + .ppt_platform_sppt_max = 115, + }, + .dc_data = &(struct power_limits) { + .ppt_apu_sppt_min = 25, + .ppt_apu_sppt_def = 30, + .ppt_apu_sppt_max = 45, + .ppt_platform_sppt_min = 40, + .ppt_platform_sppt_max = 60, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GA402X"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 35, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_def = 65, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 35, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GA403U"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 65, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 35, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GA503R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 35, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 65, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 25, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 54, + .ppt_pl2_sppt_max = 60, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65 + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GA605W"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 85, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 31, + .ppt_pl2_sppt_max = 44, + .ppt_pl3_fppt_min = 45, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GU603Z"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 60, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 135, + /* Only allowed in AC mode */ + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 40, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 40, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GU604V"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 65, + .ppt_pl1_spl_max = 120, + .ppt_pl2_sppt_min = 65, + .ppt_pl2_sppt_max = 150, + /* Only allowed in AC mode */ + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 40, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 40, + .ppt_pl2_sppt_max = 60, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GU605M"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 90, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 135, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 38, + .ppt_pl2_sppt_max = 53, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GV301Q"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 45, + .ppt_pl2_sppt_min = 65, + .ppt_pl2_sppt_max = 80, + }, + .dc_data = NULL + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GV301R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 45, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 54, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 35, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GV601R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 35, + .ppt_pl1_spl_max = 90, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 54, + .ppt_pl2_sppt_max = 100, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_def = 80, + .ppt_pl3_fppt_max = 125, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 28, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 54, + .ppt_pl2_sppt_def = 40, + .ppt_pl2_sppt_max = 60, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_def = 80, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GV601V"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_def = 100, + .ppt_pl1_spl_max = 110, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 135, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 40, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 40, + .ppt_pl2_sppt_max = 60, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "GX650P"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 110, + .ppt_pl1_spl_max = 130, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 125, + .ppt_pl2_sppt_max = 130, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_def = 125, + .ppt_pl3_fppt_max = 135, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_def = 25, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_def = 35, + .ppt_pl2_sppt_max = 65, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_def = 42, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G513I"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + /* Yes this laptop is very limited */ + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_max = 80, + }, + .dc_data = NULL, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G513QM"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + /* Yes this laptop is very limited */ + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 100, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_max = 190, + }, + .dc_data = NULL, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G513R"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 35, + .ppt_pl1_spl_max = 90, + .ppt_pl2_sppt_min = 54, + .ppt_pl2_sppt_max = 100, + .ppt_pl3_fppt_min = 54, + .ppt_pl3_fppt_max = 125, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 50, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 50, + .ppt_pl3_fppt_min = 28, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G614J"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 140, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 175, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 55, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 70, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G634J"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 140, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 175, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 55, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 70, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G733C"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 170, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 175, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 35, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G733P"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 30, + .ppt_pl1_spl_def = 100, + .ppt_pl1_spl_max = 130, + .ppt_pl2_sppt_min = 65, + .ppt_pl2_sppt_def = 125, + .ppt_pl2_sppt_max = 130, + .ppt_pl3_fppt_min = 65, + .ppt_pl3_fppt_def = 125, + .ppt_pl3_fppt_max = 130, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 65, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 65, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 75, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G814J"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 140, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 140, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 55, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 70, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "G834J"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 28, + .ppt_pl1_spl_max = 140, + .ppt_pl2_sppt_min = 28, + .ppt_pl2_sppt_max = 175, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 25, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 55, + .ppt_pl2_sppt_min = 25, + .ppt_pl2_sppt_max = 70, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + }, + .requires_fan_curve = true, + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "H7606W"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 15, + .ppt_pl1_spl_max = 80, + .ppt_pl2_sppt_min = 35, + .ppt_pl2_sppt_max = 80, + .ppt_pl3_fppt_min = 35, + .ppt_pl3_fppt_max = 80, + .nv_dynamic_boost_min = 5, + .nv_dynamic_boost_max = 20, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + .nv_tgp_min = 55, + .nv_tgp_max = 85, + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 25, + .ppt_pl1_spl_max = 35, + .ppt_pl2_sppt_min = 31, + .ppt_pl2_sppt_max = 44, + .ppt_pl3_fppt_min = 45, + .ppt_pl3_fppt_max = 65, + .nv_temp_target_min = 75, + .nv_temp_target_max = 87, + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "RC71"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 7, + .ppt_pl1_spl_max = 30, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_max = 43, + .ppt_pl3_fppt_min = 15, + .ppt_pl3_fppt_max = 53 + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 7, + .ppt_pl1_spl_def = 15, + .ppt_pl1_spl_max = 25, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_def = 20, + .ppt_pl2_sppt_max = 30, + .ppt_pl3_fppt_min = 15, + .ppt_pl3_fppt_def = 25, + .ppt_pl3_fppt_max = 35 + } + }, + }, + { + .matches = { + DMI_MATCH(DMI_BOARD_NAME, "RC72"), + }, + .driver_data = &(struct power_data) { + .ac_data = &(struct power_limits) { + .ppt_pl1_spl_min = 7, + .ppt_pl1_spl_max = 30, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_max = 43, + .ppt_pl3_fppt_min = 15, + .ppt_pl3_fppt_max = 53 + }, + .dc_data = &(struct power_limits) { + .ppt_pl1_spl_min = 7, + .ppt_pl1_spl_def = 17, + .ppt_pl1_spl_max = 25, + .ppt_pl2_sppt_min = 15, + .ppt_pl2_sppt_def = 24, + .ppt_pl2_sppt_max = 30, + .ppt_pl3_fppt_min = 15, + .ppt_pl3_fppt_def = 30, + .ppt_pl3_fppt_max = 35 + } + }, + }, + {} +}; + +#endif /* _ASUS_ARMOURY_H_ */ diff --git a/drivers/platform/x86/asus-wmi.c b/drivers/platform/x86/asus-wmi.c index 38ef778e8c19..83fe67816329 100644 --- a/drivers/platform/x86/asus-wmi.c +++ b/drivers/platform/x86/asus-wmi.c @@ -55,8 +55,6 @@ module_param(fnlock_default, bool, 0444); #define to_asus_wmi_driver(pdrv) \ (container_of((pdrv), struct asus_wmi_driver, platform_driver)) -#define ASUS_WMI_MGMT_GUID "97845ED0-4E6D-11DE-8A39-0800200C9A66" - #define NOTIFY_BRNUP_MIN 0x11 #define NOTIFY_BRNUP_MAX 0x1f #define NOTIFY_BRNDOWN_MIN 0x20 @@ -105,8 +103,6 @@ module_param(fnlock_default, bool, 0444); #define USB_INTEL_XUSB2PR 0xD0 #define PCI_DEVICE_ID_INTEL_LYNXPOINT_LP_XHCI 0x9c31 -#define ASUS_ACPI_UID_ASUSWMI "ASUSWMI" - #define WMI_EVENT_MASK 0xFFFF #define FAN_CURVE_POINTS 8 @@ -142,16 +138,20 @@ module_param(fnlock_default, bool, 0444); #define ASUS_MINI_LED_2024_STRONG 0x01 #define ASUS_MINI_LED_2024_OFF 0x02 -/* Controls the power state of the USB0 hub on ROG Ally which input is on */ #define ASUS_USB0_PWR_EC0_CSEE "\\_SB.PCI0.SBRG.EC0.CSEE" -/* 300ms so far seems to produce a reliable result on AC and battery */ -#define ASUS_USB0_PWR_EC0_CSEE_WAIT 1500 +/* + * The period required to wait after screen off/on/s2idle.check in MS. + * Time here greatly impacts the wake behaviour. Used in suspend/wake. + */ +#define ASUS_USB0_PWR_EC0_CSEE_WAIT 600 +#define ASUS_USB0_PWR_EC0_CSEE_OFF 0xB7 +#define ASUS_USB0_PWR_EC0_CSEE_ON 0xB8 static const char * const ashs_ids[] = { "ATK4001", "ATK4002", NULL }; static int throttle_thermal_policy_write(struct asus_wmi *); -static const struct dmi_system_id asus_ally_mcu_quirk[] = { +static const struct dmi_system_id asus_rog_ally_device[] = { { .matches = { DMI_MATCH(DMI_BOARD_NAME, "RC71L"), @@ -274,9 +274,6 @@ struct asus_wmi { u32 tablet_switch_dev_id; bool tablet_switch_inverted; - /* The ROG Ally device requires the MCU USB device be disconnected before suspend */ - bool ally_mcu_usb_switch; - enum fan_type fan_type; enum fan_type gpu_fan_type; enum fan_type mid_fan_type; @@ -335,6 +332,16 @@ struct asus_wmi { struct asus_wmi_driver *driver; }; +/* Global to allow setting externally without requiring driver data */ +static enum asus_ally_mcu_hack use_ally_mcu_hack = ASUS_WMI_ALLY_MCU_HACK_INIT; + +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) +static void asus_wmi_show_deprecated(void) +{ + pr_notice_once("Accessing attributes through /sys/bus/platform/asus_wmi is deprecated and will be removed in a future release. Please switch over to /sys/class/firmware_attributes.\n"); +} +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ + /* WMI ************************************************************************/ static int asus_wmi_evaluate_method3(u32 method_id, @@ -385,7 +392,7 @@ int asus_wmi_evaluate_method(u32 method_id, u32 arg0, u32 arg1, u32 *retval) { return asus_wmi_evaluate_method3(method_id, arg0, arg1, 0, retval); } -EXPORT_SYMBOL_GPL(asus_wmi_evaluate_method); +EXPORT_SYMBOL_NS_GPL(asus_wmi_evaluate_method, "ASUS_WMI"); static int asus_wmi_evaluate_method5(u32 method_id, u32 arg0, u32 arg1, u32 arg2, u32 arg3, u32 arg4, u32 *retval) @@ -549,12 +556,51 @@ static int asus_wmi_get_devstate(struct asus_wmi *asus, u32 dev_id, u32 *retval) return 0; } -static int asus_wmi_set_devstate(u32 dev_id, u32 ctrl_param, - u32 *retval) + +/** + * asus_wmi_get_devstate_dsts() - Get the WMI function state. + * @dev_id: The WMI method ID to call. + * @retval: A pointer to where to store the value returned from WMI. + * + * On success the return value is 0, and the retval is a valid value returned + * by the successful WMI function call otherwise an error is returned if the + * call failed, or if the WMI method ID is unsupported. + */ +int asus_wmi_get_devstate_dsts(u32 dev_id, u32 *retval) +{ + int err; + + err = asus_wmi_evaluate_method(ASUS_WMI_METHODID_DSTS, dev_id, 0, retval); + if (err) + return err; + + if (*retval == ASUS_WMI_UNSUPPORTED_METHOD) + return -ENODEV; + + return 0; +} +EXPORT_SYMBOL_NS_GPL(asus_wmi_get_devstate_dsts, "ASUS_WMI"); + +/** + * asus_wmi_set_devstate() - Set the WMI function state. + * @dev_id: The WMI function to call. + * @ctrl_param: The argument to be used for this WMI function. + * @retval: A pointer to where to store the value returned from WMI. + * + * The returned WMI function state if not checked here for error as + * asus_wmi_set_devstate() is not called unless first paired with a call to + * asus_wmi_get_devstate_dsts() to check that the WMI function is supported. + * + * On success the return value is 0, and the retval is a valid value returned + * by the successful WMI function call. An error value is returned only if the + * WMI function failed. + */ +int asus_wmi_set_devstate(u32 dev_id, u32 ctrl_param, u32 *retval) { return asus_wmi_evaluate_method(ASUS_WMI_METHODID_DEVS, dev_id, ctrl_param, retval); } +EXPORT_SYMBOL_NS_GPL(asus_wmi_set_devstate, "ASUS_WMI"); /* Helper for special devices with magic return codes */ static int asus_wmi_get_devstate_bits(struct asus_wmi *asus, @@ -687,6 +733,7 @@ static void asus_wmi_tablet_mode_get_state(struct asus_wmi *asus) } /* Charging mode, 1=Barrel, 2=USB ******************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t charge_mode_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -697,12 +744,16 @@ static ssize_t charge_mode_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", value & 0xff); } static DEVICE_ATTR_RO(charge_mode); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* dGPU ********************************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t dgpu_disable_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -713,6 +764,8 @@ static ssize_t dgpu_disable_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -766,8 +819,10 @@ static ssize_t dgpu_disable_store(struct device *dev, return count; } static DEVICE_ATTR_RW(dgpu_disable); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* eGPU ********************************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t egpu_enable_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -778,6 +833,8 @@ static ssize_t egpu_enable_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -834,8 +891,10 @@ static ssize_t egpu_enable_store(struct device *dev, return count; } static DEVICE_ATTR_RW(egpu_enable); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Is eGPU connected? *********************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t egpu_connected_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -846,12 +905,16 @@ static ssize_t egpu_connected_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } static DEVICE_ATTR_RO(egpu_connected); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* gpu mux switch *************************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t gpu_mux_mode_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -862,6 +925,8 @@ static ssize_t gpu_mux_mode_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -920,6 +985,7 @@ static ssize_t gpu_mux_mode_store(struct device *dev, return count; } static DEVICE_ATTR_RW(gpu_mux_mode); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* TUF Laptop Keyboard RGB Modes **********************************************/ static ssize_t kbd_rgb_mode_store(struct device *dev, @@ -1043,6 +1109,7 @@ static const struct attribute_group *kbd_rgb_mode_groups[] = { }; /* Tunable: PPT: Intel=PL1, AMD=SPPT *****************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t ppt_pl2_sppt_store(struct device *dev, struct device_attribute *attr, const char *buf, size_t count) @@ -1081,6 +1148,8 @@ static ssize_t ppt_pl2_sppt_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->ppt_pl2_sppt); } static DEVICE_ATTR_RW(ppt_pl2_sppt); @@ -1123,6 +1192,8 @@ static ssize_t ppt_pl1_spl_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->ppt_pl1_spl); } static DEVICE_ATTR_RW(ppt_pl1_spl); @@ -1166,6 +1237,8 @@ static ssize_t ppt_fppt_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->ppt_fppt); } static DEVICE_ATTR_RW(ppt_fppt); @@ -1209,6 +1282,8 @@ static ssize_t ppt_apu_sppt_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->ppt_apu_sppt); } static DEVICE_ATTR_RW(ppt_apu_sppt); @@ -1252,6 +1327,8 @@ static ssize_t ppt_platform_sppt_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->ppt_platform_sppt); } static DEVICE_ATTR_RW(ppt_platform_sppt); @@ -1295,6 +1372,8 @@ static ssize_t nv_dynamic_boost_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->nv_dynamic_boost); } static DEVICE_ATTR_RW(nv_dynamic_boost); @@ -1338,11 +1417,53 @@ static ssize_t nv_temp_target_show(struct device *dev, { struct asus_wmi *asus = dev_get_drvdata(dev); + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%u\n", asus->nv_temp_target); } static DEVICE_ATTR_RW(nv_temp_target); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Ally MCU Powersave ********************************************************/ + +/* + * The HID driver needs to check MCU version and set this to false if the MCU FW + * version is >= the minimum requirements. New FW do not need the hacks. + */ +void set_ally_mcu_hack(enum asus_ally_mcu_hack status) +{ + use_ally_mcu_hack = status; + pr_debug("%s Ally MCU suspend quirk\n", + status == ASUS_WMI_ALLY_MCU_HACK_ENABLED ? "Enabled" : "Disabled"); +} +EXPORT_SYMBOL_NS_GPL(set_ally_mcu_hack, "ASUS_WMI"); + +/* + * mcu_powersave should be enabled always, as it is fixed in MCU FW versions: + * - v313 for Ally X + * - v319 for Ally 1 + * The HID driver checks MCU versions and so should set this if requirements match + */ +void set_ally_mcu_powersave(bool enabled) +{ + int result, err; + + err = asus_wmi_set_devstate(ASUS_WMI_DEVID_MCU_POWERSAVE, enabled, &result); + if (err) { + pr_warn("Failed to set MCU powersave: %d\n", err); + return; + } + if (result > 1) { + pr_warn("Failed to set MCU powersave (result): 0x%x\n", result); + return; + } + + pr_debug("%s MCU Powersave\n", + enabled ? "Enabled" : "Disabled"); +} +EXPORT_SYMBOL_NS_GPL(set_ally_mcu_powersave, "ASUS_WMI"); + +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t mcu_powersave_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -1353,6 +1474,8 @@ static ssize_t mcu_powersave_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -1388,6 +1511,7 @@ static ssize_t mcu_powersave_store(struct device *dev, return count; } static DEVICE_ATTR_RW(mcu_powersave); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Battery ********************************************************************/ @@ -2261,6 +2385,7 @@ static int asus_wmi_rfkill_init(struct asus_wmi *asus) } /* Panel Overdrive ************************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t panel_od_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -2271,6 +2396,8 @@ static ssize_t panel_od_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -2307,9 +2434,10 @@ static ssize_t panel_od_store(struct device *dev, return count; } static DEVICE_ATTR_RW(panel_od); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Bootup sound ***************************************************************/ - +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t boot_sound_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -2320,6 +2448,8 @@ static ssize_t boot_sound_show(struct device *dev, if (result < 0) return result; + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", result); } @@ -2355,8 +2485,10 @@ static ssize_t boot_sound_store(struct device *dev, return count; } static DEVICE_ATTR_RW(boot_sound); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Mini-LED mode **************************************************************/ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t mini_led_mode_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -2387,6 +2519,8 @@ static ssize_t mini_led_mode_show(struct device *dev, } } + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "%d\n", value); } @@ -2457,10 +2591,13 @@ static ssize_t available_mini_led_mode_show(struct device *dev, return sysfs_emit(buf, "0 1 2\n"); } + asus_wmi_show_deprecated(); + return sysfs_emit(buf, "0\n"); } static DEVICE_ATTR_RO(available_mini_led_mode); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Quirks *********************************************************************/ @@ -3748,6 +3885,7 @@ static int throttle_thermal_policy_set_default(struct asus_wmi *asus) return throttle_thermal_policy_write(asus); } +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) static ssize_t throttle_thermal_policy_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -3791,6 +3929,7 @@ static ssize_t throttle_thermal_policy_store(struct device *dev, * Throttle thermal policy: 0 - default, 1 - overboost, 2 - silent */ static DEVICE_ATTR_RW(throttle_thermal_policy); +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ /* Platform profile ***********************************************************/ static int asus_wmi_platform_profile_get(struct device *dev, @@ -3810,7 +3949,7 @@ static int asus_wmi_platform_profile_get(struct device *dev, *profile = PLATFORM_PROFILE_PERFORMANCE; break; case ASUS_THROTTLE_THERMAL_POLICY_SILENT: - *profile = PLATFORM_PROFILE_QUIET; + *profile = PLATFORM_PROFILE_LOW_POWER; break; default: return -EINVAL; @@ -3834,7 +3973,7 @@ static int asus_wmi_platform_profile_set(struct device *dev, case PLATFORM_PROFILE_BALANCED: tp = ASUS_THROTTLE_THERMAL_POLICY_DEFAULT; break; - case PLATFORM_PROFILE_QUIET: + case PLATFORM_PROFILE_LOW_POWER: tp = ASUS_THROTTLE_THERMAL_POLICY_SILENT; break; default: @@ -3847,7 +3986,7 @@ static int asus_wmi_platform_profile_set(struct device *dev, static int asus_wmi_platform_profile_probe(void *drvdata, unsigned long *choices) { - set_bit(PLATFORM_PROFILE_QUIET, choices); + set_bit(PLATFORM_PROFILE_LOW_POWER, choices); set_bit(PLATFORM_PROFILE_BALANCED, choices); set_bit(PLATFORM_PROFILE_PERFORMANCE, choices); @@ -4392,27 +4531,29 @@ static struct attribute *platform_attributes[] = { &dev_attr_camera.attr, &dev_attr_cardr.attr, &dev_attr_touchpad.attr, - &dev_attr_charge_mode.attr, - &dev_attr_egpu_enable.attr, - &dev_attr_egpu_connected.attr, - &dev_attr_dgpu_disable.attr, - &dev_attr_gpu_mux_mode.attr, &dev_attr_lid_resume.attr, &dev_attr_als_enable.attr, &dev_attr_fan_boost_mode.attr, - &dev_attr_throttle_thermal_policy.attr, - &dev_attr_ppt_pl2_sppt.attr, - &dev_attr_ppt_pl1_spl.attr, - &dev_attr_ppt_fppt.attr, - &dev_attr_ppt_apu_sppt.attr, - &dev_attr_ppt_platform_sppt.attr, - &dev_attr_nv_dynamic_boost.attr, - &dev_attr_nv_temp_target.attr, - &dev_attr_mcu_powersave.attr, - &dev_attr_boot_sound.attr, - &dev_attr_panel_od.attr, - &dev_attr_mini_led_mode.attr, - &dev_attr_available_mini_led_mode.attr, +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) + &dev_attr_charge_mode.attr, + &dev_attr_egpu_enable.attr, + &dev_attr_egpu_connected.attr, + &dev_attr_dgpu_disable.attr, + &dev_attr_gpu_mux_mode.attr, + &dev_attr_ppt_pl2_sppt.attr, + &dev_attr_ppt_pl1_spl.attr, + &dev_attr_ppt_fppt.attr, + &dev_attr_ppt_apu_sppt.attr, + &dev_attr_ppt_platform_sppt.attr, + &dev_attr_nv_dynamic_boost.attr, + &dev_attr_nv_temp_target.attr, + &dev_attr_mcu_powersave.attr, + &dev_attr_boot_sound.attr, + &dev_attr_panel_od.attr, + &dev_attr_mini_led_mode.attr, + &dev_attr_available_mini_led_mode.attr, + &dev_attr_throttle_thermal_policy.attr, +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ NULL }; @@ -4434,7 +4575,11 @@ static umode_t asus_sysfs_is_visible(struct kobject *kobj, devid = ASUS_WMI_DEVID_LID_RESUME; else if (attr == &dev_attr_als_enable.attr) devid = ASUS_WMI_DEVID_ALS_ENABLE; - else if (attr == &dev_attr_charge_mode.attr) + else if (attr == &dev_attr_fan_boost_mode.attr) + ok = asus->fan_boost_mode_available; + +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) + if (attr == &dev_attr_charge_mode.attr) devid = ASUS_WMI_DEVID_CHARGE_MODE; else if (attr == &dev_attr_egpu_enable.attr) ok = asus->egpu_enable_available; @@ -4472,6 +4617,7 @@ static umode_t asus_sysfs_is_visible(struct kobject *kobj, ok = asus->mini_led_dev_id != 0; else if (attr == &dev_attr_available_mini_led_mode.attr) ok = asus->mini_led_dev_id != 0; +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ if (devid != -1) { ok = !(asus_wmi_get_devstate_simple(asus, devid) < 0); @@ -4711,7 +4857,23 @@ static int asus_wmi_add(struct platform_device *pdev) if (err) goto fail_platform; + if (use_ally_mcu_hack == ASUS_WMI_ALLY_MCU_HACK_INIT) { + if (acpi_has_method(NULL, ASUS_USB0_PWR_EC0_CSEE) + && dmi_check_system(asus_rog_ally_device)) + use_ally_mcu_hack = ASUS_WMI_ALLY_MCU_HACK_ENABLED; + if (dmi_match(DMI_BOARD_NAME, "RC71")) { + /* + * These steps ensure the device is in a valid good state, this is + * especially important for the Ally 1 after a reboot. + */ + acpi_execute_simple_method(NULL, ASUS_USB0_PWR_EC0_CSEE, + ASUS_USB0_PWR_EC0_CSEE_ON); + msleep(ASUS_USB0_PWR_EC0_CSEE_WAIT); + } + } + /* ensure defaults for tunables */ +#if IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) asus->ppt_pl2_sppt = 5; asus->ppt_pl1_spl = 5; asus->ppt_apu_sppt = 5; @@ -4723,8 +4885,6 @@ static int asus_wmi_add(struct platform_device *pdev) asus->egpu_enable_available = asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_EGPU); asus->dgpu_disable_available = asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_DGPU); asus->kbd_rgb_state_available = asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_TUF_RGB_STATE); - asus->ally_mcu_usb_switch = acpi_has_method(NULL, ASUS_USB0_PWR_EC0_CSEE) - && dmi_check_system(asus_ally_mcu_quirk); if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_MINI_LED_MODE)) asus->mini_led_dev_id = ASUS_WMI_DEVID_MINI_LED_MODE; @@ -4735,17 +4895,18 @@ static int asus_wmi_add(struct platform_device *pdev) asus->gpu_mux_dev = ASUS_WMI_DEVID_GPU_MUX; else if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_GPU_MUX_VIVO)) asus->gpu_mux_dev = ASUS_WMI_DEVID_GPU_MUX_VIVO; - - if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_TUF_RGB_MODE)) - asus->kbd_rgb_dev = ASUS_WMI_DEVID_TUF_RGB_MODE; - else if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_TUF_RGB_MODE2)) - asus->kbd_rgb_dev = ASUS_WMI_DEVID_TUF_RGB_MODE2; +#endif /* IS_ENABLED(CONFIG_ASUS_WMI_DEPRECATED_ATTRS) */ if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_THROTTLE_THERMAL_POLICY)) asus->throttle_thermal_policy_dev = ASUS_WMI_DEVID_THROTTLE_THERMAL_POLICY; else if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_THROTTLE_THERMAL_POLICY_VIVO)) asus->throttle_thermal_policy_dev = ASUS_WMI_DEVID_THROTTLE_THERMAL_POLICY_VIVO; + if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_TUF_RGB_MODE)) + asus->kbd_rgb_dev = ASUS_WMI_DEVID_TUF_RGB_MODE; + else if (asus_wmi_dev_is_present(asus, ASUS_WMI_DEVID_TUF_RGB_MODE2)) + asus->kbd_rgb_dev = ASUS_WMI_DEVID_TUF_RGB_MODE2; + err = fan_boost_mode_check_present(asus); if (err) goto fail_fan_boost_mode; @@ -4910,34 +5071,6 @@ static int asus_hotk_resume(struct device *device) return 0; } -static int asus_hotk_resume_early(struct device *device) -{ - struct asus_wmi *asus = dev_get_drvdata(device); - - if (asus->ally_mcu_usb_switch) { - /* sleep required to prevent USB0 being yanked then reappearing rapidly */ - if (ACPI_FAILURE(acpi_execute_simple_method(NULL, ASUS_USB0_PWR_EC0_CSEE, 0xB8))) - dev_err(device, "ROG Ally MCU failed to connect USB dev\n"); - else - msleep(ASUS_USB0_PWR_EC0_CSEE_WAIT); - } - return 0; -} - -static int asus_hotk_prepare(struct device *device) -{ - struct asus_wmi *asus = dev_get_drvdata(device); - - if (asus->ally_mcu_usb_switch) { - /* sleep required to ensure USB0 is disabled before sleep continues */ - if (ACPI_FAILURE(acpi_execute_simple_method(NULL, ASUS_USB0_PWR_EC0_CSEE, 0xB7))) - dev_err(device, "ROG Ally MCU failed to disconnect USB dev\n"); - else - msleep(ASUS_USB0_PWR_EC0_CSEE_WAIT); - } - return 0; -} - static int asus_hotk_restore(struct device *device) { struct asus_wmi *asus = dev_get_drvdata(device); @@ -4978,11 +5111,34 @@ static int asus_hotk_restore(struct device *device) return 0; } +static void asus_ally_s2idle_restore(void) +{ + if (use_ally_mcu_hack == ASUS_WMI_ALLY_MCU_HACK_ENABLED) { + acpi_execute_simple_method(NULL, ASUS_USB0_PWR_EC0_CSEE, + ASUS_USB0_PWR_EC0_CSEE_ON); + msleep(ASUS_USB0_PWR_EC0_CSEE_WAIT); + } +} + +static int asus_hotk_prepare(struct device *device) +{ + if (use_ally_mcu_hack == ASUS_WMI_ALLY_MCU_HACK_ENABLED) { + acpi_execute_simple_method(NULL, ASUS_USB0_PWR_EC0_CSEE, + ASUS_USB0_PWR_EC0_CSEE_OFF); + msleep(ASUS_USB0_PWR_EC0_CSEE_WAIT); + } + return 0; +} + +/* Use only for Ally devices due to the wake_on_ac */ +static struct acpi_s2idle_dev_ops asus_ally_s2idle_dev_ops = { + .restore = asus_ally_s2idle_restore, +}; + static const struct dev_pm_ops asus_pm_ops = { .thaw = asus_hotk_thaw, .restore = asus_hotk_restore, .resume = asus_hotk_resume, - .resume_early = asus_hotk_resume_early, .prepare = asus_hotk_prepare, }; @@ -5010,6 +5166,10 @@ static int asus_wmi_probe(struct platform_device *pdev) return ret; } + ret = acpi_register_lps0_dev(&asus_ally_s2idle_dev_ops); + if (ret) + pr_warn("failed to register LPS0 sleep handler in asus-wmi\n"); + return asus_wmi_add(pdev); } @@ -5042,6 +5202,7 @@ EXPORT_SYMBOL_GPL(asus_wmi_register_driver); void asus_wmi_unregister_driver(struct asus_wmi_driver *driver) { + acpi_unregister_lps0_dev(&asus_ally_s2idle_dev_ops); platform_device_unregister(driver->platform_device); platform_driver_unregister(&driver->platform_driver); used = false; diff --git a/include/linux/platform_data/x86/asus-wmi.h b/include/linux/platform_data/x86/asus-wmi.h index 783e2a336861..78261ea49995 100644 --- a/include/linux/platform_data/x86/asus-wmi.h +++ b/include/linux/platform_data/x86/asus-wmi.h @@ -6,6 +6,9 @@ #include #include +#define ASUS_WMI_MGMT_GUID "97845ED0-4E6D-11DE-8A39-0800200C9A66" +#define ASUS_ACPI_UID_ASUSWMI "ASUSWMI" + /* WMI Methods */ #define ASUS_WMI_METHODID_SPEC 0x43455053 /* BIOS SPECification */ #define ASUS_WMI_METHODID_SFBD 0x44424653 /* Set First Boot Device */ @@ -73,12 +76,14 @@ #define ASUS_WMI_DEVID_THROTTLE_THERMAL_POLICY_VIVO 0x00110019 /* Misc */ +#define ASUS_WMI_DEVID_PANEL_HD 0x0005001C #define ASUS_WMI_DEVID_PANEL_OD 0x00050019 #define ASUS_WMI_DEVID_CAMERA 0x00060013 #define ASUS_WMI_DEVID_LID_FLIP 0x00060062 #define ASUS_WMI_DEVID_LID_FLIP_ROG 0x00060077 #define ASUS_WMI_DEVID_MINI_LED_MODE 0x0005001E #define ASUS_WMI_DEVID_MINI_LED_MODE2 0x0005002E +#define ASUS_WMI_DEVID_SCREEN_AUTO_BRIGHTNESS 0x0005002A /* Storage */ #define ASUS_WMI_DEVID_CARDREADER 0x00080013 @@ -133,6 +138,16 @@ /* dgpu on/off */ #define ASUS_WMI_DEVID_DGPU 0x00090020 +/* Intel E-core and P-core configuration in a format 0x0[E]0[P] */ +#define ASUS_WMI_DEVID_CORES 0x001200D2 + /* Maximum Intel E-core and P-core availability */ +#define ASUS_WMI_DEVID_CORES_MAX 0x001200D3 + +#define ASUS_WMI_DEVID_APU_MEM 0x000600C1 + +#define ASUS_WMI_DEVID_DGPU_BASE_TGP 0x00120099 +#define ASUS_WMI_DEVID_DGPU_SET_TGP 0x00120098 + /* gpu mux switch, 0 = dGPU, 1 = Optimus */ #define ASUS_WMI_DEVID_GPU_MUX 0x00090016 #define ASUS_WMI_DEVID_GPU_MUX_VIVO 0x00090026 @@ -157,9 +172,37 @@ #define ASUS_WMI_DSTS_MAX_BRIGTH_MASK 0x0000FF00 #define ASUS_WMI_DSTS_LIGHTBAR_MASK 0x0000000F +enum asus_ally_mcu_hack { + ASUS_WMI_ALLY_MCU_HACK_INIT, + ASUS_WMI_ALLY_MCU_HACK_ENABLED, + ASUS_WMI_ALLY_MCU_HACK_DISABLED, +}; + #if IS_REACHABLE(CONFIG_ASUS_WMI) +void set_ally_mcu_hack(enum asus_ally_mcu_hack status); +void set_ally_mcu_powersave(bool enabled); +int asus_wmi_get_devstate_dsts(u32 dev_id, u32 *retval); +int asus_wmi_set_devstate(u32 dev_id, u32 ctrl_param, u32 *retval); int asus_wmi_evaluate_method(u32 method_id, u32 arg0, u32 arg1, u32 *retval); #else +static inline void set_ally_mcu_hack(enum asus_ally_mcu_hack status) +{ +} +static inline void set_ally_mcu_powersave(bool enabled) +{ +} +static inline int asus_wmi_set_devstate(u32 dev_id, u32 ctrl_param, u32 *retval) +{ + return -ENODEV; +} +static inline int asus_wmi_get_devstate_dsts(u32 dev_id, u32 *retval) +{ + return -ENODEV; +} +static inline int asus_wmi_set_devstate(u32 dev_id, u32 ctrl_param, u32 *retval) +{ + return -ENODEV; +} static inline int asus_wmi_evaluate_method(u32 method_id, u32 arg0, u32 arg1, u32 *retval) { -- 2.49.0.391.g4bbb303af6 From 768ec00d141959898429a82cf6f7d610b8c90577 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:32:06 +0200 Subject: [PATCH 4/9] bbr3 Signed-off-by: Peter Jung --- include/linux/tcp.h | 6 +- include/net/inet_connection_sock.h | 4 +- include/net/tcp.h | 72 +- include/uapi/linux/inet_diag.h | 23 + include/uapi/linux/rtnetlink.h | 4 +- include/uapi/linux/tcp.h | 1 + net/ipv4/Kconfig | 21 +- net/ipv4/bpf_tcp_ca.c | 4 +- net/ipv4/tcp.c | 3 + net/ipv4/tcp_bbr.c | 2231 +++++++++++++++++++++------- net/ipv4/tcp_cong.c | 1 + net/ipv4/tcp_input.c | 40 +- net/ipv4/tcp_minisocks.c | 2 + net/ipv4/tcp_output.c | 48 +- net/ipv4/tcp_rate.c | 30 +- net/ipv4/tcp_timer.c | 1 + 16 files changed, 1937 insertions(+), 554 deletions(-) diff --git a/include/linux/tcp.h b/include/linux/tcp.h index f88daaa76d83..e569fd1ed7e8 100644 --- a/include/linux/tcp.h +++ b/include/linux/tcp.h @@ -243,7 +243,8 @@ struct tcp_sock { /* OOO segments go in this rbtree. Socket lock must be held. */ struct rb_root out_of_order_queue; u32 snd_ssthresh; /* Slow start size threshold */ - u8 recvmsg_inq : 1;/* Indicate # of bytes in queue upon recvmsg */ + u32 recvmsg_inq : 1,/* Indicate # of bytes in queue upon recvmsg */ + fast_ack_mode:1;/* ack ASAP if >1 rcv_mss received? */ __cacheline_group_end(tcp_sock_read_rx); /* TX read-write hotpath cache lines */ @@ -300,7 +301,8 @@ struct tcp_sock { */ struct tcp_options_received rx_opt; u8 nonagle : 4,/* Disable Nagle algorithm? */ - rate_app_limited:1; /* rate_{delivered,interval_us} limited? */ + rate_app_limited:1, /* rate_{delivered,interval_us} limited? */ + tlp_orig_data_app_limited:1; /* app-limited before TLP rtx? */ __cacheline_group_end(tcp_sock_write_txrx); /* RX read-write hotpath cache lines */ diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h index c7f42844c79a..170250145598 100644 --- a/include/net/inet_connection_sock.h +++ b/include/net/inet_connection_sock.h @@ -137,8 +137,8 @@ struct inet_connection_sock { u32 icsk_probes_tstamp; u32 icsk_user_timeout; - u64 icsk_ca_priv[104 / sizeof(u64)]; -#define ICSK_CA_PRIV_SIZE sizeof_field(struct inet_connection_sock, icsk_ca_priv) +#define ICSK_CA_PRIV_SIZE (144) + u64 icsk_ca_priv[ICSK_CA_PRIV_SIZE / sizeof(u64)]; }; #define ICSK_TIME_RETRANS 1 /* Retransmit timer */ diff --git a/include/net/tcp.h b/include/net/tcp.h index 2d08473a6dc0..aa80dd0abe5a 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -376,6 +376,8 @@ static inline void tcp_dec_quickack_mode(struct sock *sk) #define TCP_ECN_QUEUE_CWR 2 #define TCP_ECN_DEMAND_CWR 4 #define TCP_ECN_SEEN 8 +#define TCP_ECN_LOW 16 +#define TCP_ECN_ECT_PERMANENT 32 enum tcp_tw_status { TCP_TW_SUCCESS = 0, @@ -796,6 +798,15 @@ static inline void tcp_fast_path_check(struct sock *sk) u32 tcp_delack_max(const struct sock *sk); +static inline void tcp_set_ecn_low_from_dst(struct sock *sk, + const struct dst_entry *dst) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (dst_feature(dst, RTAX_FEATURE_ECN_LOW)) + tp->ecn_flags |= TCP_ECN_LOW; +} + /* Compute the actual rto_min value */ static inline u32 tcp_rto_min(const struct sock *sk) { @@ -901,6 +912,11 @@ static inline u32 tcp_stamp_us_delta(u64 t1, u64 t0) return max_t(s64, t1 - t0, 0); } +static inline u32 tcp_stamp32_us_delta(u32 t1, u32 t0) +{ + return max_t(s32, t1 - t0, 0); +} + /* provide the departure time in us unit */ static inline u64 tcp_skb_timestamp_us(const struct sk_buff *skb) { @@ -990,9 +1006,14 @@ struct tcp_skb_cb { /* pkts S/ACKed so far upon tx of skb, incl retrans: */ __u32 delivered; /* start of send pipeline phase */ - u64 first_tx_mstamp; + u32 first_tx_mstamp; /* when we reached the "delivered" count */ - u64 delivered_mstamp; + u32 delivered_mstamp; +#define TCPCB_IN_FLIGHT_BITS 20 +#define TCPCB_IN_FLIGHT_MAX ((1U << TCPCB_IN_FLIGHT_BITS) - 1) + u32 in_flight:20, /* packets in flight at transmit */ + unused2:12; + u32 lost; /* packets lost so far upon tx of skb */ } tx; /* only used for outgoing skbs */ union { struct inet_skb_parm h4; @@ -1105,6 +1126,7 @@ enum tcp_ca_event { CA_EVENT_LOSS, /* loss timeout */ CA_EVENT_ECN_NO_CE, /* ECT set, but not CE marked */ CA_EVENT_ECN_IS_CE, /* received CE marked IP packet */ + CA_EVENT_TLP_RECOVERY, /* a lost segment was repaired by TLP probe */ }; /* Information about inbound ACK, passed to cong_ops->in_ack_event() */ @@ -1127,7 +1149,11 @@ enum tcp_ca_ack_event_flags { #define TCP_CONG_NON_RESTRICTED 0x1 /* Requires ECN/ECT set on all packets */ #define TCP_CONG_NEEDS_ECN 0x2 -#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | TCP_CONG_NEEDS_ECN) +/* Wants notification of CE events (CA_EVENT_ECN_IS_CE, CA_EVENT_ECN_NO_CE). */ +#define TCP_CONG_WANTS_CE_EVENTS 0x4 +#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | \ + TCP_CONG_NEEDS_ECN | \ + TCP_CONG_WANTS_CE_EVENTS) union tcp_cc_info; @@ -1147,10 +1173,13 @@ struct ack_sample { */ struct rate_sample { u64 prior_mstamp; /* starting timestamp for interval */ + u32 prior_lost; /* tp->lost at "prior_mstamp" */ u32 prior_delivered; /* tp->delivered at "prior_mstamp" */ u32 prior_delivered_ce;/* tp->delivered_ce at "prior_mstamp" */ + u32 tx_in_flight; /* packets in flight at starting timestamp */ + s32 lost; /* number of packets lost over interval */ s32 delivered; /* number of packets delivered over interval */ - s32 delivered_ce; /* number of packets delivered w/ CE marks*/ + s32 delivered_ce; /* packets delivered w/ CE mark over interval */ long interval_us; /* time for tp->delivered to incr "delivered" */ u32 snd_interval_us; /* snd interval for delivered packets */ u32 rcv_interval_us; /* rcv interval for delivered packets */ @@ -1161,7 +1190,9 @@ struct rate_sample { u32 last_end_seq; /* end_seq of most recently ACKed packet */ bool is_app_limited; /* is sample from packet with bubble in pipe? */ bool is_retrans; /* is sample from retransmission? */ + bool is_acking_tlp_retrans_seq; /* ACKed a TLP retransmit sequence? */ bool is_ack_delayed; /* is this (likely) a delayed ACK? */ + bool is_ece; /* did this ACK have ECN marked? */ }; struct tcp_congestion_ops { @@ -1185,8 +1216,11 @@ struct tcp_congestion_ops { /* hook for packet ack accounting (optional) */ void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample); - /* override sysctl_tcp_min_tso_segs */ - u32 (*min_tso_segs)(struct sock *sk); + /* pick target number of segments per TSO/GSO skb (optional): */ + u32 (*tso_segs)(struct sock *sk, unsigned int mss_now); + + /* react to a specific lost skb (optional) */ + void (*skb_marked_lost)(struct sock *sk, const struct sk_buff *skb); /* call when packets are delivered to update cwnd and pacing rate, * after all the ca_state processing. (optional) @@ -1252,6 +1286,14 @@ static inline char *tcp_ca_get_name_by_key(u32 key, char *buffer) } #endif +static inline bool tcp_ca_wants_ce_events(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + return icsk->icsk_ca_ops->flags & (TCP_CONG_NEEDS_ECN | + TCP_CONG_WANTS_CE_EVENTS); +} + static inline bool tcp_ca_needs_ecn(const struct sock *sk) { const struct inet_connection_sock *icsk = inet_csk(sk); @@ -1271,6 +1313,7 @@ static inline void tcp_ca_event(struct sock *sk, const enum tcp_ca_event event) void tcp_set_ca_state(struct sock *sk, const u8 ca_state); /* From tcp_rate.c */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, struct rate_sample *rs); @@ -1283,6 +1326,21 @@ static inline bool tcp_skb_sent_after(u64 t1, u64 t2, u32 seq1, u32 seq2) return t1 > t2 || (t1 == t2 && after(seq1, seq2)); } +/* If a retransmit failed due to local qdisc congestion or other local issues, + * then we may have called tcp_set_skb_tso_segs() to increase the number of + * segments in the skb without increasing the tx.in_flight. In all other cases, + * the tx.in_flight should be at least as big as the pcount of the sk_buff. We + * do not have the state to know whether a retransmit failed due to local qdisc + * congestion or other local issues, so to avoid spurious warnings we consider + * that any skb marked lost may have suffered that fate. + */ +static inline bool tcp_skb_tx_in_flight_is_suspicious(u32 skb_pcount, + u32 skb_sacked_flags, + u32 tx_in_flight) +{ + return (skb_pcount > tx_in_flight) && !(skb_sacked_flags & TCPCB_LOST); +} + /* These functions determine how the current flow behaves in respect of SACK * handling. SACK is negotiated with the peer, and therefore it can vary * between different flows. @@ -2434,7 +2492,7 @@ struct tcp_plb_state { u8 consec_cong_rounds:5, /* consecutive congested rounds */ unused:3; u32 pause_until; /* jiffies32 when PLB can resume rerouting */ -}; +} __attribute__ ((__packed__)); static inline void tcp_plb_init(const struct sock *sk, struct tcp_plb_state *plb) diff --git a/include/uapi/linux/inet_diag.h b/include/uapi/linux/inet_diag.h index 86bb2e8b17c9..9d9a3eb2ce9b 100644 --- a/include/uapi/linux/inet_diag.h +++ b/include/uapi/linux/inet_diag.h @@ -229,6 +229,29 @@ struct tcp_bbr_info { __u32 bbr_min_rtt; /* min-filtered RTT in uSec */ __u32 bbr_pacing_gain; /* pacing gain shifted left 8 bits */ __u32 bbr_cwnd_gain; /* cwnd gain shifted left 8 bits */ + __u32 bbr_bw_hi_lsb; /* lower 32 bits of bw_hi */ + __u32 bbr_bw_hi_msb; /* upper 32 bits of bw_hi */ + __u32 bbr_bw_lo_lsb; /* lower 32 bits of bw_lo */ + __u32 bbr_bw_lo_msb; /* upper 32 bits of bw_lo */ + __u8 bbr_mode; /* current bbr_mode in state machine */ + __u8 bbr_phase; /* current state machine phase */ + __u8 unused1; /* alignment padding; not used yet */ + __u8 bbr_version; /* BBR algorithm version */ + __u32 bbr_inflight_lo; /* lower short-term data volume bound */ + __u32 bbr_inflight_hi; /* higher long-term data volume bound */ + __u32 bbr_extra_acked; /* max excess packets ACKed in epoch */ +}; + +/* TCP BBR congestion control bbr_phase as reported in netlink/ss stats. */ +enum tcp_bbr_phase { + BBR_PHASE_INVALID = 0, + BBR_PHASE_STARTUP = 1, + BBR_PHASE_DRAIN = 2, + BBR_PHASE_PROBE_RTT = 3, + BBR_PHASE_PROBE_BW_UP = 4, + BBR_PHASE_PROBE_BW_DOWN = 5, + BBR_PHASE_PROBE_BW_CRUISE = 6, + BBR_PHASE_PROBE_BW_REFILL = 7, }; union tcp_cc_info { diff --git a/include/uapi/linux/rtnetlink.h b/include/uapi/linux/rtnetlink.h index 66c3903d29cf..dfdbc1c0b606 100644 --- a/include/uapi/linux/rtnetlink.h +++ b/include/uapi/linux/rtnetlink.h @@ -516,12 +516,14 @@ enum { #define RTAX_FEATURE_TIMESTAMP (1 << 2) /* unused */ #define RTAX_FEATURE_ALLFRAG (1 << 3) /* unused */ #define RTAX_FEATURE_TCP_USEC_TS (1 << 4) +#define RTAX_FEATURE_ECN_LOW (1 << 5) #define RTAX_FEATURE_MASK (RTAX_FEATURE_ECN | \ RTAX_FEATURE_SACK | \ RTAX_FEATURE_TIMESTAMP | \ RTAX_FEATURE_ALLFRAG | \ - RTAX_FEATURE_TCP_USEC_TS) + RTAX_FEATURE_TCP_USEC_TS | \ + RTAX_FEATURE_ECN_LOW) struct rta_session { __u8 proto; diff --git a/include/uapi/linux/tcp.h b/include/uapi/linux/tcp.h index dbf896f3146c..92b6d6472951 100644 --- a/include/uapi/linux/tcp.h +++ b/include/uapi/linux/tcp.h @@ -178,6 +178,7 @@ enum tcp_fastopen_client_fail { #define TCPI_OPT_ECN_SEEN 16 /* we received at least one packet with ECT */ #define TCPI_OPT_SYN_DATA 32 /* SYN-ACK acked data in SYN sent or rcvd */ #define TCPI_OPT_USEC_TS 64 /* usec timestamps */ +#define TCPI_OPT_ECN_LOW 128 /* Low-latency ECN enabled at conn init */ /* * Sender's congestion state indicating normal or abnormal situations diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 6d2c97f8e9ef..ddc116ef22cb 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -669,15 +669,18 @@ config TCP_CONG_BBR default n help - BBR (Bottleneck Bandwidth and RTT) TCP congestion control aims to - maximize network utilization and minimize queues. It builds an explicit - model of the bottleneck delivery rate and path round-trip propagation - delay. It tolerates packet loss and delay unrelated to congestion. It - can operate over LAN, WAN, cellular, wifi, or cable modem links. It can - coexist with flows that use loss-based congestion control, and can - operate with shallow buffers, deep buffers, bufferbloat, policers, or - AQM schemes that do not provide a delay signal. It requires the fq - ("Fair Queue") pacing packet scheduler. + BBR (Bottleneck Bandwidth and RTT) TCP congestion control is a + model-based congestion control algorithm that aims to maximize + network utilization, keep queues and retransmit rates low, and to be + able to coexist with Reno/CUBIC in common scenarios. It builds an + explicit model of the network path. It tolerates a targeted degree + of random packet loss and delay. It can operate over LAN, WAN, + cellular, wifi, or cable modem links, and can use shallow-threshold + ECN signals. It can coexist to some degree with flows that use + loss-based congestion control, and can operate with shallow buffers, + deep buffers, bufferbloat, policers, or AQM schemes that do not + provide a delay signal. It requires pacing, using either TCP internal + pacing or the fq ("Fair Queue") pacing packet scheduler. choice prompt "Default TCP congestion control" diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 554804774628..fb6ab6ca8440 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -280,7 +280,7 @@ static void bpf_tcp_ca_pkts_acked(struct sock *sk, const struct ack_sample *samp { } -static u32 bpf_tcp_ca_min_tso_segs(struct sock *sk) +static u32 bpf_tcp_ca_tso_segs(struct sock *sk, unsigned int mss_now) { return 0; } @@ -315,7 +315,7 @@ static struct tcp_congestion_ops __bpf_ops_tcp_congestion_ops = { .cwnd_event = bpf_tcp_ca_cwnd_event, .in_ack_event = bpf_tcp_ca_in_ack_event, .pkts_acked = bpf_tcp_ca_pkts_acked, - .min_tso_segs = bpf_tcp_ca_min_tso_segs, + .tso_segs = bpf_tcp_ca_tso_segs, .cong_control = bpf_tcp_ca_cong_control, .undo_cwnd = bpf_tcp_ca_undo_cwnd, .sndbuf_expand = bpf_tcp_ca_sndbuf_expand, diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 57df7c1d2faa..47605d71f68b 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -3398,6 +3398,7 @@ int tcp_disconnect(struct sock *sk, int flags) tp->rx_opt.dsack = 0; tp->rx_opt.num_sacks = 0; tp->rcv_ooopack = 0; + tp->fast_ack_mode = 0; /* Clean up fastopen related fields */ @@ -4124,6 +4125,8 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_options |= TCPI_OPT_ECN; if (tp->ecn_flags & TCP_ECN_SEEN) info->tcpi_options |= TCPI_OPT_ECN_SEEN; + if (tp->ecn_flags & TCP_ECN_LOW) + info->tcpi_options |= TCPI_OPT_ECN_LOW; if (tp->syn_data_acked) info->tcpi_options |= TCPI_OPT_SYN_DATA; if (tp->tcp_usec_ts) diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index 760941e55153..516a5daac694 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -1,18 +1,19 @@ -/* Bottleneck Bandwidth and RTT (BBR) congestion control +/* BBR (Bottleneck Bandwidth and RTT) congestion control * - * BBR congestion control computes the sending rate based on the delivery - * rate (throughput) estimated from ACKs. In a nutshell: + * BBR is a model-based congestion control algorithm that aims for low queues, + * low loss, and (bounded) Reno/CUBIC coexistence. To maintain a model of the + * network path, it uses measurements of bandwidth and RTT, as well as (if they + * occur) packet loss and/or shallow-threshold ECN signals. Note that although + * it can use ECN or loss signals explicitly, it does not require either; it + * can bound its in-flight data based on its estimate of the BDP. * - * On each ACK, update our model of the network path: - * bottleneck_bandwidth = windowed_max(delivered / elapsed, 10 round trips) - * min_rtt = windowed_min(rtt, 10 seconds) - * pacing_rate = pacing_gain * bottleneck_bandwidth - * cwnd = max(cwnd_gain * bottleneck_bandwidth * min_rtt, 4) - * - * The core algorithm does not react directly to packet losses or delays, - * although BBR may adjust the size of next send per ACK when loss is - * observed, or adjust the sending rate if it estimates there is a - * traffic policer, in order to keep the drop rate reasonable. + * The model has both higher and lower bounds for the operating range: + * lo: bw_lo, inflight_lo: conservative short-term lower bound + * hi: bw_hi, inflight_hi: robust long-term upper bound + * The bandwidth-probing time scale is (a) extended dynamically based on + * estimated BDP to improve coexistence with Reno/CUBIC; (b) bounded by + * an interactive wall-clock time-scale to be more scalable and responsive + * than Reno and CUBIC. * * Here is a state transition diagram for BBR: * @@ -65,6 +66,13 @@ #include #include +#include +#include "tcp_dctcp.h" + +#define BBR_VERSION 3 + +#define bbr_param(sk,name) (bbr_ ## name) + /* Scale factor for rate in pkt/uSec unit to avoid truncation in bandwidth * estimation. The rate unit ~= (1500 bytes / 1 usec / 2^24) ~= 715 bps. * This handles bandwidths from 0.06pps (715bps) to 256Mpps (3Tbps) in a u32. @@ -85,36 +93,41 @@ enum bbr_mode { BBR_PROBE_RTT, /* cut inflight to min to probe min_rtt */ }; +/* How does the incoming ACK stream relate to our bandwidth probing? */ +enum bbr_ack_phase { + BBR_ACKS_INIT, /* not probing; not getting probe feedback */ + BBR_ACKS_REFILLING, /* sending at est. bw to fill pipe */ + BBR_ACKS_PROBE_STARTING, /* inflight rising to probe bw */ + BBR_ACKS_PROBE_FEEDBACK, /* getting feedback from bw probing */ + BBR_ACKS_PROBE_STOPPING, /* stopped probing; still getting feedback */ +}; + /* BBR congestion control block */ struct bbr { u32 min_rtt_us; /* min RTT in min_rtt_win_sec window */ u32 min_rtt_stamp; /* timestamp of min_rtt_us */ u32 probe_rtt_done_stamp; /* end time for BBR_PROBE_RTT mode */ - struct minmax bw; /* Max recent delivery rate in pkts/uS << 24 */ - u32 rtt_cnt; /* count of packet-timed rounds elapsed */ + u32 probe_rtt_min_us; /* min RTT in probe_rtt_win_ms win */ + u32 probe_rtt_min_stamp; /* timestamp of probe_rtt_min_us*/ u32 next_rtt_delivered; /* scb->tx.delivered at end of round */ u64 cycle_mstamp; /* time of this cycle phase start */ - u32 mode:3, /* current bbr_mode in state machine */ + u32 mode:2, /* current bbr_mode in state machine */ prev_ca_state:3, /* CA state on previous ACK */ - packet_conservation:1, /* use packet conservation? */ round_start:1, /* start of packet-timed tx->ack round? */ + ce_state:1, /* If most recent data has CE bit set */ + bw_probe_up_rounds:5, /* cwnd-limited rounds in PROBE_UP */ + try_fast_path:1, /* can we take fast path? */ idle_restart:1, /* restarting after idle? */ probe_rtt_round_done:1, /* a BBR_PROBE_RTT round at 4 pkts? */ - unused:13, - lt_is_sampling:1, /* taking long-term ("LT") samples now? */ - lt_rtt_cnt:7, /* round trips in long-term interval */ - lt_use_bw:1; /* use lt_bw as our bw estimate? */ - u32 lt_bw; /* LT est delivery rate in pkts/uS << 24 */ - u32 lt_last_delivered; /* LT intvl start: tp->delivered */ - u32 lt_last_stamp; /* LT intvl start: tp->delivered_mstamp */ - u32 lt_last_lost; /* LT intvl start: tp->lost */ + init_cwnd:7, /* initial cwnd */ + unused_1:10; u32 pacing_gain:10, /* current gain for setting pacing rate */ cwnd_gain:10, /* current gain for setting cwnd */ full_bw_reached:1, /* reached full bw in Startup? */ full_bw_cnt:2, /* number of rounds without large bw gains */ - cycle_idx:3, /* current index in pacing_gain cycle array */ + cycle_idx:2, /* current index in pacing_gain cycle array */ has_seen_rtt:1, /* have we seen an RTT sample yet? */ - unused_b:5; + unused_2:6; u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ u32 full_bw; /* recent bw, to estimate if pipe is full */ @@ -124,19 +137,67 @@ struct bbr { u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ extra_acked_win_idx:1, /* current index in extra_acked array */ - unused_c:6; + /* BBR v3 state: */ + full_bw_now:1, /* recently reached full bw plateau? */ + startup_ecn_rounds:2, /* consecutive hi ECN STARTUP rounds */ + loss_in_cycle:1, /* packet loss in this cycle? */ + ecn_in_cycle:1, /* ECN in this cycle? */ + unused_3:1; + u32 loss_round_delivered; /* scb->tx.delivered ending loss round */ + u32 undo_bw_lo; /* bw_lo before latest losses */ + u32 undo_inflight_lo; /* inflight_lo before latest losses */ + u32 undo_inflight_hi; /* inflight_hi before latest losses */ + u32 bw_latest; /* max delivered bw in last round trip */ + u32 bw_lo; /* lower bound on sending bandwidth */ + u32 bw_hi[2]; /* max recent measured bw sample */ + u32 inflight_latest; /* max delivered data in last round trip */ + u32 inflight_lo; /* lower bound of inflight data range */ + u32 inflight_hi; /* upper bound of inflight data range */ + u32 bw_probe_up_cnt; /* packets delivered per inflight_hi incr */ + u32 bw_probe_up_acks; /* packets (S)ACKed since inflight_hi incr */ + u32 probe_wait_us; /* PROBE_DOWN until next clock-driven probe */ + u32 prior_rcv_nxt; /* tp->rcv_nxt when CE state last changed */ + u32 ecn_eligible:1, /* sender can use ECN (RTT, handshake)? */ + ecn_alpha:9, /* EWMA delivered_ce/delivered; 0..256 */ + bw_probe_samples:1, /* rate samples reflect bw probing? */ + prev_probe_too_high:1, /* did last PROBE_UP go too high? */ + stopped_risky_probe:1, /* last PROBE_UP stopped due to risk? */ + rounds_since_probe:8, /* packet-timed rounds since probed bw */ + loss_round_start:1, /* loss_round_delivered round trip? */ + loss_in_round:1, /* loss marked in this round trip? */ + ecn_in_round:1, /* ECN marked in this round trip? */ + ack_phase:3, /* bbr_ack_phase: meaning of ACKs */ + loss_events_in_round:4,/* losses in STARTUP round */ + initialized:1; /* has bbr_init() been called? */ + u32 alpha_last_delivered; /* tp->delivered at alpha update */ + u32 alpha_last_delivered_ce; /* tp->delivered_ce at alpha update */ + + u8 unused_4; /* to preserve alignment */ + struct tcp_plb_state plb; }; -#define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ +struct bbr_context { + u32 sample_bw; +}; -/* Window length of bw filter (in rounds): */ -static const int bbr_bw_rtts = CYCLE_LEN + 2; /* Window length of min_rtt filter (in sec): */ static const u32 bbr_min_rtt_win_sec = 10; /* Minimum time (in ms) spent at bbr_cwnd_min_target in BBR_PROBE_RTT mode: */ static const u32 bbr_probe_rtt_mode_ms = 200; -/* Skip TSO below the following bandwidth (bits/sec): */ -static const int bbr_min_tso_rate = 1200000; +/* Window length of probe_rtt_min_us filter (in ms), and consequently the + * typical interval between PROBE_RTT mode entries. The default is 5000ms. + * Note that bbr_probe_rtt_win_ms must be <= bbr_min_rtt_win_sec * MSEC_PER_SEC + */ +static const u32 bbr_probe_rtt_win_ms = 5000; +/* Proportion of cwnd to estimated BDP in PROBE_RTT, in units of BBR_UNIT: */ +static const u32 bbr_probe_rtt_cwnd_gain = BBR_UNIT * 1 / 2; + +/* Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting + * in bigger TSO bursts. We cut the RTT-based allowance in half + * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance + * is below 1500 bytes after 6 * ~500 usec = 3ms. + */ +static const u32 bbr_tso_rtt_shift = 9; /* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. * In order to help drive the network toward lower queues and low latency while @@ -146,13 +207,15 @@ static const int bbr_min_tso_rate = 1200000; */ static const int bbr_pacing_margin_percent = 1; -/* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain +/* We use a startup_pacing_gain of 4*ln(2) because it's the smallest value * that will allow a smoothly increasing pacing rate that will double each RTT * and send the same number of packets per RTT that an un-paced, slow-starting * Reno or CUBIC flow would: */ -static const int bbr_high_gain = BBR_UNIT * 2885 / 1000 + 1; -/* The pacing gain of 1/high_gain in BBR_DRAIN is calculated to typically drain +static const int bbr_startup_pacing_gain = BBR_UNIT * 277 / 100 + 1; +/* The gain for deriving startup cwnd: */ +static const int bbr_startup_cwnd_gain = BBR_UNIT * 2; +/* The pacing gain in BBR_DRAIN is calculated to typically drain * the queue created in BBR_STARTUP in a single round: */ static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; @@ -160,13 +223,17 @@ static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; static const int bbr_cwnd_gain = BBR_UNIT * 2; /* The pacing_gain values for the PROBE_BW gain cycle, to discover/share bw: */ static const int bbr_pacing_gain[] = { - BBR_UNIT * 5 / 4, /* probe for more available bw */ - BBR_UNIT * 3 / 4, /* drain queue and/or yield bw to other flows */ - BBR_UNIT, BBR_UNIT, BBR_UNIT, /* cruise at 1.0*bw to utilize pipe, */ - BBR_UNIT, BBR_UNIT, BBR_UNIT /* without creating excess queue... */ + BBR_UNIT * 5 / 4, /* UP: probe for more available bw */ + BBR_UNIT * 91 / 100, /* DOWN: drain queue and/or yield bw */ + BBR_UNIT, /* CRUISE: try to use pipe w/ some headroom */ + BBR_UNIT, /* REFILL: refill pipe to estimated 100% */ +}; +enum bbr_pacing_gain_phase { + BBR_BW_PROBE_UP = 0, /* push up inflight to probe for bw/vol */ + BBR_BW_PROBE_DOWN = 1, /* drain excess inflight from the queue */ + BBR_BW_PROBE_CRUISE = 2, /* use pipe, w/ headroom in queue/pipe */ + BBR_BW_PROBE_REFILL = 3, /* refill the pipe again to 100% */ }; -/* Randomize the starting gain cycling phase over N phases: */ -static const u32 bbr_cycle_rand = 7; /* Try to keep at least this many packets in flight, if things go smoothly. For * smooth functioning, a sliding window protocol ACKing every other packet @@ -174,24 +241,12 @@ static const u32 bbr_cycle_rand = 7; */ static const u32 bbr_cwnd_min_target = 4; -/* To estimate if BBR_STARTUP mode (i.e. high_gain) has filled pipe... */ +/* To estimate if BBR_STARTUP or BBR_BW_PROBE_UP has filled pipe... */ /* If bw has increased significantly (1.25x), there may be more bw available: */ static const u32 bbr_full_bw_thresh = BBR_UNIT * 5 / 4; /* But after 3 rounds w/o significant bw growth, estimate pipe is full: */ static const u32 bbr_full_bw_cnt = 3; -/* "long-term" ("LT") bandwidth estimator parameters... */ -/* The minimum number of rounds in an LT bw sampling interval: */ -static const u32 bbr_lt_intvl_min_rtts = 4; -/* If lost/delivered ratio > 20%, interval is "lossy" and we may be policed: */ -static const u32 bbr_lt_loss_thresh = 50; -/* If 2 intervals have a bw ratio <= 1/8, their bw is "consistent": */ -static const u32 bbr_lt_bw_ratio = BBR_UNIT / 8; -/* If 2 intervals have a bw diff <= 4 Kbit/sec their bw is "consistent": */ -static const u32 bbr_lt_bw_diff = 4000 / 8; -/* If we estimate we're policed, use lt_bw for this many round trips: */ -static const u32 bbr_lt_bw_max_rtts = 48; - /* Gain factor for adding extra_acked to target cwnd: */ static const int bbr_extra_acked_gain = BBR_UNIT; /* Window length of extra_acked window. */ @@ -201,8 +256,121 @@ static const u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; /* Time period for clamping cwnd increment due to ack aggregation */ static const u32 bbr_extra_acked_max_us = 100 * 1000; +/* Flags to control BBR ECN-related behavior... */ + +/* Ensure ACKs only ACK packets with consistent ECN CE status? */ +static const bool bbr_precise_ece_ack = true; + +/* Max RTT (in usec) at which to use sender-side ECN logic. + * Disabled when 0 (ECN allowed at any RTT). + */ +static const u32 bbr_ecn_max_rtt_us = 5000; + +/* On losses, scale down inflight and pacing rate by beta scaled by BBR_SCALE. + * No loss response when 0. + */ +static const u32 bbr_beta = BBR_UNIT * 30 / 100; + +/* Gain factor for ECN mark ratio samples, scaled by BBR_SCALE (1/16 = 6.25%) */ +static const u32 bbr_ecn_alpha_gain = BBR_UNIT * 1 / 16; + +/* The initial value for ecn_alpha; 1.0 allows a flow to respond quickly + * to congestion if the bottleneck is congested when the flow starts up. + */ +static const u32 bbr_ecn_alpha_init = BBR_UNIT; + +/* On ECN, cut inflight_lo to (1 - ecn_factor * ecn_alpha) scaled by BBR_SCALE. + * No ECN based bounding when 0. + */ +static const u32 bbr_ecn_factor = BBR_UNIT * 1 / 3; /* 1/3 = 33% */ + +/* Estimate bw probing has gone too far if CE ratio exceeds this threshold. + * Scaled by BBR_SCALE. Disabled when 0. + */ +static const u32 bbr_ecn_thresh = BBR_UNIT * 1 / 2; /* 1/2 = 50% */ + +/* If non-zero, if in a cycle with no losses but some ECN marks, after ECN + * clears then make the first round's increment to inflight_hi the following + * fraction of inflight_hi. + */ +static const u32 bbr_ecn_reprobe_gain = BBR_UNIT * 1 / 2; + +/* Estimate bw probing has gone too far if loss rate exceeds this level. */ +static const u32 bbr_loss_thresh = BBR_UNIT * 2 / 100; /* 2% loss */ + +/* Slow down for a packet loss recovered by TLP? */ +static const bool bbr_loss_probe_recovery = true; + +/* Exit STARTUP if number of loss marking events in a Recovery round is >= N, + * and loss rate is higher than bbr_loss_thresh. + * Disabled if 0. + */ +static const u32 bbr_full_loss_cnt = 6; + +/* Exit STARTUP if number of round trips with ECN mark rate above ecn_thresh + * meets this count. + */ +static const u32 bbr_full_ecn_cnt = 2; + +/* Fraction of unutilized headroom to try to leave in path upon high loss. */ +static const u32 bbr_inflight_headroom = BBR_UNIT * 15 / 100; + +/* How much do we increase cwnd_gain when probing for bandwidth in + * BBR_BW_PROBE_UP? This specifies the increment in units of + * BBR_UNIT/4. The default is 1, meaning 0.25. + * The min value is 0 (meaning 0.0); max is 3 (meaning 0.75). + */ +static const u32 bbr_bw_probe_cwnd_gain = 1; + +/* Max number of packet-timed rounds to wait before probing for bandwidth. If + * we want to tolerate 1% random loss per round, and not have this cut our + * inflight too much, we must probe for bw periodically on roughly this scale. + * If low, limits Reno/CUBIC coexistence; if high, limits loss tolerance. + * We aim to be fair with Reno/CUBIC up to a BDP of at least: + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + */ +static const u32 bbr_bw_probe_max_rounds = 63; + +/* Max amount of randomness to inject in round counting for Reno-coexistence. + */ +static const u32 bbr_bw_probe_rand_rounds = 2; + +/* Use BBR-native probe time scale starting at this many usec. + * We aim to be fair with Reno/CUBIC up to an inter-loss time epoch of at least: + * BDP*RTT = 25Mbps * .030sec /(1514bytes) * 0.030sec = 1.9 secs + */ +static const u32 bbr_bw_probe_base_us = 2 * USEC_PER_SEC; /* 2 secs */ + +/* Use BBR-native probes spread over this many usec: */ +static const u32 bbr_bw_probe_rand_us = 1 * USEC_PER_SEC; /* 1 secs */ + +/* Use fast path if app-limited, no loss/ECN, and target cwnd was reached? */ +static const bool bbr_fast_path = true; + +/* Use fast ack mode? */ +static const bool bbr_fast_ack_mode = true; + +static u32 bbr_max_bw(const struct sock *sk); +static u32 bbr_bw(const struct sock *sk); +static void bbr_exit_probe_rtt(struct sock *sk); +static void bbr_reset_congestion_signals(struct sock *sk); +static void bbr_run_loss_probe_recovery(struct sock *sk); + static void bbr_check_probe_rtt_done(struct sock *sk); +/* This connection can use ECN if both endpoints have signaled ECN support in + * the handshake and the per-route settings indicated this is a + * shallow-threshold ECN environment, meaning both: + * (a) ECN CE marks indicate low-latency/shallow-threshold congestion, and + * (b) TCP endpoints provide precise ACKs that only ACK data segments + * with consistent ECN CE status + */ +static bool bbr_can_use_ecn(const struct sock *sk) +{ + return (tcp_sk(sk)->ecn_flags & TCP_ECN_OK) && + (tcp_sk(sk)->ecn_flags & TCP_ECN_LOW); +} + /* Do we estimate that STARTUP filled the pipe? */ static bool bbr_full_bw_reached(const struct sock *sk) { @@ -214,17 +382,17 @@ static bool bbr_full_bw_reached(const struct sock *sk) /* Return the windowed max recent bandwidth sample, in pkts/uS << BW_SCALE. */ static u32 bbr_max_bw(const struct sock *sk) { - struct bbr *bbr = inet_csk_ca(sk); + const struct bbr *bbr = inet_csk_ca(sk); - return minmax_get(&bbr->bw); + return max(bbr->bw_hi[0], bbr->bw_hi[1]); } /* Return the estimated bandwidth of the path, in pkts/uS << BW_SCALE. */ static u32 bbr_bw(const struct sock *sk) { - struct bbr *bbr = inet_csk_ca(sk); + const struct bbr *bbr = inet_csk_ca(sk); - return bbr->lt_use_bw ? bbr->lt_bw : bbr_max_bw(sk); + return min(bbr_max_bw(sk), bbr->bw_lo); } /* Return maximum extra acked in past k-2k round trips, @@ -241,15 +409,23 @@ static u16 bbr_extra_acked(const struct sock *sk) * The order here is chosen carefully to avoid overflow of u64. This should * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. */ -static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain) +static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain, + int margin) { unsigned int mss = tcp_sk(sk)->mss_cache; rate *= mss; rate *= gain; rate >>= BBR_SCALE; - rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_margin_percent); - return rate >> BW_SCALE; + rate *= USEC_PER_SEC / 100 * (100 - margin); + rate >>= BW_SCALE; + rate = max(rate, 1ULL); + return rate; +} + +static u64 bbr_bw_bytes_per_sec(struct sock *sk, u64 rate) +{ + return bbr_rate_bytes_per_sec(sk, rate, BBR_UNIT, 0); } /* Convert a BBR bw and gain factor to a pacing rate in bytes per second. */ @@ -257,12 +433,13 @@ static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) { u64 rate = bw; - rate = bbr_rate_bytes_per_sec(sk, rate, gain); + rate = bbr_rate_bytes_per_sec(sk, rate, gain, + bbr_pacing_margin_percent); rate = min_t(u64, rate, READ_ONCE(sk->sk_max_pacing_rate)); return rate; } -/* Initialize pacing rate to: high_gain * init_cwnd / RTT. */ +/* Initialize pacing rate to: startup_pacing_gain * init_cwnd / RTT. */ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); @@ -279,7 +456,8 @@ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) bw = (u64)tcp_snd_cwnd(tp) * BW_UNIT; do_div(bw, rtt_us); WRITE_ONCE(sk->sk_pacing_rate, - bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain)); + bbr_bw_to_pacing_rate(sk, bw, + bbr_param(sk, startup_pacing_gain))); } /* Pace using current bw estimate and a gain factor. */ @@ -295,26 +473,48 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) WRITE_ONCE(sk->sk_pacing_rate, rate); } -/* override sysctl_tcp_min_tso_segs */ -__bpf_kfunc static u32 bbr_min_tso_segs(struct sock *sk) +/* Return the number of segments BBR would like in a TSO/GSO skb, given a + * particular max gso size as a constraint. TODO: make this simpler and more + * consistent by switching bbr to just call tcp_tso_autosize(). + */ +static u32 bbr_tso_segs_generic(struct sock *sk, unsigned int mss_now, + u32 gso_max_size) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 segs, r; + u64 bytes; + + /* Budget a TSO/GSO burst size allowance based on bw (pacing_rate). */ + bytes = READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift); + + /* Budget a TSO/GSO burst size allowance based on min_rtt. For every + * K = 2^tso_rtt_shift microseconds of min_rtt, halve the burst. + * The min_rtt-based burst allowance is: 64 KBytes / 2^(min_rtt/K) + */ + if (bbr_param(sk, tso_rtt_shift)) { + r = bbr->min_rtt_us >> bbr_param(sk, tso_rtt_shift); + if (r < BITS_PER_TYPE(u32)) /* prevent undefined behavior */ + bytes += GSO_LEGACY_MAX_SIZE >> r; + } + + bytes = min_t(u32, bytes, gso_max_size - 1 - MAX_TCP_HEADER); + segs = max_t(u32, bytes / mss_now, + sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + return segs; +} + +/* Custom tcp_tso_autosize() for BBR, used at transmit time to cap skb size. */ +__bpf_kfunc static u32 bbr_tso_segs(struct sock *sk, unsigned int mss_now) { - return READ_ONCE(sk->sk_pacing_rate) < (bbr_min_tso_rate >> 3) ? 1 : 2; + return bbr_tso_segs_generic(sk, mss_now, sk->sk_gso_max_size); } +/* Like bbr_tso_segs(), using mss_cache, ignoring driver's sk_gso_max_size. */ static u32 bbr_tso_segs_goal(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); - u32 segs, bytes; - - /* Sort of tcp_tso_autosize() but ignoring - * driver provided sk_gso_max_size. - */ - bytes = min_t(unsigned long, - READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift), - GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); - segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); - return min(segs, 0x7FU); + return bbr_tso_segs_generic(sk, tp->mss_cache, GSO_LEGACY_MAX_SIZE); } /* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ @@ -334,7 +534,9 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - if (event == CA_EVENT_TX_START && tp->app_limited) { + if (event == CA_EVENT_TX_START) { + if (!tp->app_limited) + return; bbr->idle_restart = 1; bbr->ack_epoch_mstamp = tp->tcp_mstamp; bbr->ack_epoch_acked = 0; @@ -345,6 +547,16 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) bbr_set_pacing_rate(sk, bbr_bw(sk), BBR_UNIT); else if (bbr->mode == BBR_PROBE_RTT) bbr_check_probe_rtt_done(sk); + } else if ((event == CA_EVENT_ECN_IS_CE || + event == CA_EVENT_ECN_NO_CE) && + bbr_can_use_ecn(sk) && + bbr_param(sk, precise_ece_ack)) { + u32 state = bbr->ce_state; + dctcp_ece_ack_update(sk, event, &bbr->prior_rcv_nxt, &state); + bbr->ce_state = state; + } else if (event == CA_EVENT_TLP_RECOVERY && + bbr_param(sk, loss_probe_recovery)) { + bbr_run_loss_probe_recovery(sk); } } @@ -367,10 +579,10 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) * default. This should only happen when the connection is not using TCP * timestamps and has retransmitted all of the SYN/SYNACK/data packets * ACKed so far. In this case, an RTO can cut cwnd to 1, in which - * case we need to slow-start up toward something safe: TCP_INIT_CWND. + * case we need to slow-start up toward something safe: initial cwnd. */ if (unlikely(bbr->min_rtt_us == ~0U)) /* no valid RTT samples yet? */ - return TCP_INIT_CWND; /* be safe: cap at default initial cwnd*/ + return bbr->init_cwnd; /* be safe: cap at initial cwnd */ w = (u64)bw * bbr->min_rtt_us; @@ -387,23 +599,23 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) * - one skb in sending host Qdisc, * - one skb in sending host TSO/GSO engine * - one skb being received by receiver host LRO/GRO/delayed-ACK engine - * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because - * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, + * Don't worry, at low rates this won't bloat cwnd because + * in such cases tso_segs_goal is small. The minimum cwnd is 4 packets, * which allows 2 outstanding 2-packet sequences, to try to keep pipe * full even with ACK-every-other-packet delayed ACKs. */ static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd) { struct bbr *bbr = inet_csk_ca(sk); + u32 tso_segs_goal; - /* Allow enough full-sized skbs in flight to utilize end systems. */ - cwnd += 3 * bbr_tso_segs_goal(sk); - - /* Reduce delayed ACKs by rounding up cwnd to the next even number. */ - cwnd = (cwnd + 1) & ~1U; + tso_segs_goal = 3 * bbr_tso_segs_goal(sk); + /* Allow enough full-sized skbs in flight to utilize end systems. */ + cwnd = max_t(u32, cwnd, tso_segs_goal); + cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); /* Ensure gain cycling gets inflight above BDP even for small BDPs. */ - if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == 0) + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) cwnd += 2; return cwnd; @@ -458,10 +670,10 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) { u32 max_aggr_cwnd, aggr_cwnd = 0; - if (bbr_extra_acked_gain && bbr_full_bw_reached(sk)) { + if (bbr_param(sk, extra_acked_gain)) { max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) / BW_UNIT; - aggr_cwnd = (bbr_extra_acked_gain * bbr_extra_acked(sk)) + aggr_cwnd = (bbr_param(sk, extra_acked_gain) * bbr_extra_acked(sk)) >> BBR_SCALE; aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); } @@ -469,66 +681,27 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) return aggr_cwnd; } -/* An optimization in BBR to reduce losses: On the first round of recovery, we - * follow the packet conservation principle: send P packets per P packets acked. - * After that, we slow-start and send at most 2*P packets per P packets acked. - * After recovery finishes, or upon undo, we restore the cwnd we had when - * recovery started (capped by the target cwnd based on estimated BDP). - * - * TODO(ycheng/ncardwell): implement a rate-based approach. - */ -static bool bbr_set_cwnd_to_recover_or_restore( - struct sock *sk, const struct rate_sample *rs, u32 acked, u32 *new_cwnd) +/* Returns the cwnd for PROBE_RTT mode. */ +static u32 bbr_probe_rtt_cwnd(struct sock *sk) { - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - u8 prev_state = bbr->prev_ca_state, state = inet_csk(sk)->icsk_ca_state; - u32 cwnd = tcp_snd_cwnd(tp); - - /* An ACK for P pkts should release at most 2*P packets. We do this - * in two steps. First, here we deduct the number of lost packets. - * Then, in bbr_set_cwnd() we slow start up toward the target cwnd. - */ - if (rs->losses > 0) - cwnd = max_t(s32, cwnd - rs->losses, 1); - - if (state == TCP_CA_Recovery && prev_state != TCP_CA_Recovery) { - /* Starting 1st round of Recovery, so do packet conservation. */ - bbr->packet_conservation = 1; - bbr->next_rtt_delivered = tp->delivered; /* start round now */ - /* Cut unused cwnd from app behavior, TSQ, or TSO deferral: */ - cwnd = tcp_packets_in_flight(tp) + acked; - } else if (prev_state >= TCP_CA_Recovery && state < TCP_CA_Recovery) { - /* Exiting loss recovery; restore cwnd saved before recovery. */ - cwnd = max(cwnd, bbr->prior_cwnd); - bbr->packet_conservation = 0; - } - bbr->prev_ca_state = state; - - if (bbr->packet_conservation) { - *new_cwnd = max(cwnd, tcp_packets_in_flight(tp) + acked); - return true; /* yes, using packet conservation */ - } - *new_cwnd = cwnd; - return false; + return max_t(u32, bbr_param(sk, cwnd_min_target), + bbr_bdp(sk, bbr_bw(sk), bbr_param(sk, probe_rtt_cwnd_gain))); } /* Slow-start up toward target cwnd (if bw estimate is growing, or packet loss * has drawn us down below target), or snap down to target if we're above it. */ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, - u32 acked, u32 bw, int gain) + u32 acked, u32 bw, int gain, u32 cwnd, + struct bbr_context *ctx) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u32 cwnd = tcp_snd_cwnd(tp), target_cwnd = 0; + u32 target_cwnd = 0; if (!acked) goto done; /* no packet fully ACKed; just apply caps */ - if (bbr_set_cwnd_to_recover_or_restore(sk, rs, acked, &cwnd)) - goto done; - target_cwnd = bbr_bdp(sk, bw, gain); /* Increment the cwnd to account for excess ACKed data that seems @@ -537,74 +710,26 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, target_cwnd += bbr_ack_aggregation_cwnd(sk); target_cwnd = bbr_quantization_budget(sk, target_cwnd); - /* If we're below target cwnd, slow start cwnd toward target cwnd. */ - if (bbr_full_bw_reached(sk)) /* only cut cwnd if we filled the pipe */ - cwnd = min(cwnd + acked, target_cwnd); - else if (cwnd < target_cwnd || tp->delivered < TCP_INIT_CWND) - cwnd = cwnd + acked; - cwnd = max(cwnd, bbr_cwnd_min_target); + /* Update cwnd and enable fast path if cwnd reaches target_cwnd. */ + bbr->try_fast_path = 0; + if (bbr_full_bw_reached(sk)) { /* only cut cwnd if we filled the pipe */ + cwnd += acked; + if (cwnd >= target_cwnd) { + cwnd = target_cwnd; + bbr->try_fast_path = 1; + } + } else if (cwnd < target_cwnd || cwnd < 2 * bbr->init_cwnd) { + cwnd += acked; + } else { + bbr->try_fast_path = 1; + } + cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); done: - tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* apply global cap */ + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* global cap */ if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ - tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), bbr_cwnd_min_target)); -} - -/* End cycle phase if it's time and/or we hit the phase's in-flight target. */ -static bool bbr_is_next_cycle_phase(struct sock *sk, - const struct rate_sample *rs) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - bool is_full_length = - tcp_stamp_us_delta(tp->delivered_mstamp, bbr->cycle_mstamp) > - bbr->min_rtt_us; - u32 inflight, bw; - - /* The pacing_gain of 1.0 paces at the estimated bw to try to fully - * use the pipe without increasing the queue. - */ - if (bbr->pacing_gain == BBR_UNIT) - return is_full_length; /* just use wall clock time */ - - inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); - bw = bbr_max_bw(sk); - - /* A pacing_gain > 1.0 probes for bw by trying to raise inflight to at - * least pacing_gain*BDP; this may take more than min_rtt if min_rtt is - * small (e.g. on a LAN). We do not persist if packets are lost, since - * a path with small buffers may not hold that much. - */ - if (bbr->pacing_gain > BBR_UNIT) - return is_full_length && - (rs->losses || /* perhaps pacing_gain*BDP won't fit */ - inflight >= bbr_inflight(sk, bw, bbr->pacing_gain)); - - /* A pacing_gain < 1.0 tries to drain extra queue we added if bw - * probing didn't find more bw. If inflight falls to match BDP then we - * estimate queue is drained; persisting would underutilize the pipe. - */ - return is_full_length || - inflight <= bbr_inflight(sk, bw, BBR_UNIT); -} - -static void bbr_advance_cycle_phase(struct sock *sk) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - - bbr->cycle_idx = (bbr->cycle_idx + 1) & (CYCLE_LEN - 1); - bbr->cycle_mstamp = tp->delivered_mstamp; -} - -/* Gain cycling: cycle pacing gain to converge to fair share of available bw. */ -static void bbr_update_cycle_phase(struct sock *sk, - const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - - if (bbr->mode == BBR_PROBE_BW && bbr_is_next_cycle_phase(sk, rs)) - bbr_advance_cycle_phase(sk); + tcp_snd_cwnd_set(tp, min_t(u32, tcp_snd_cwnd(tp), + bbr_probe_rtt_cwnd(sk))); } static void bbr_reset_startup_mode(struct sock *sk) @@ -614,191 +739,49 @@ static void bbr_reset_startup_mode(struct sock *sk) bbr->mode = BBR_STARTUP; } -static void bbr_reset_probe_bw_mode(struct sock *sk) -{ - struct bbr *bbr = inet_csk_ca(sk); - - bbr->mode = BBR_PROBE_BW; - bbr->cycle_idx = CYCLE_LEN - 1 - get_random_u32_below(bbr_cycle_rand); - bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ -} - -static void bbr_reset_mode(struct sock *sk) -{ - if (!bbr_full_bw_reached(sk)) - bbr_reset_startup_mode(sk); - else - bbr_reset_probe_bw_mode(sk); -} - -/* Start a new long-term sampling interval. */ -static void bbr_reset_lt_bw_sampling_interval(struct sock *sk) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - - bbr->lt_last_stamp = div_u64(tp->delivered_mstamp, USEC_PER_MSEC); - bbr->lt_last_delivered = tp->delivered; - bbr->lt_last_lost = tp->lost; - bbr->lt_rtt_cnt = 0; -} - -/* Completely reset long-term bandwidth sampling. */ -static void bbr_reset_lt_bw_sampling(struct sock *sk) -{ - struct bbr *bbr = inet_csk_ca(sk); - - bbr->lt_bw = 0; - bbr->lt_use_bw = 0; - bbr->lt_is_sampling = false; - bbr_reset_lt_bw_sampling_interval(sk); -} - -/* Long-term bw sampling interval is done. Estimate whether we're policed. */ -static void bbr_lt_bw_interval_done(struct sock *sk, u32 bw) -{ - struct bbr *bbr = inet_csk_ca(sk); - u32 diff; - - if (bbr->lt_bw) { /* do we have bw from a previous interval? */ - /* Is new bw close to the lt_bw from the previous interval? */ - diff = abs(bw - bbr->lt_bw); - if ((diff * BBR_UNIT <= bbr_lt_bw_ratio * bbr->lt_bw) || - (bbr_rate_bytes_per_sec(sk, diff, BBR_UNIT) <= - bbr_lt_bw_diff)) { - /* All criteria are met; estimate we're policed. */ - bbr->lt_bw = (bw + bbr->lt_bw) >> 1; /* avg 2 intvls */ - bbr->lt_use_bw = 1; - bbr->pacing_gain = BBR_UNIT; /* try to avoid drops */ - bbr->lt_rtt_cnt = 0; - return; - } - } - bbr->lt_bw = bw; - bbr_reset_lt_bw_sampling_interval(sk); -} - -/* Token-bucket traffic policers are common (see "An Internet-Wide Analysis of - * Traffic Policing", SIGCOMM 2016). BBR detects token-bucket policers and - * explicitly models their policed rate, to reduce unnecessary losses. We - * estimate that we're policed if we see 2 consecutive sampling intervals with - * consistent throughput and high packet loss. If we think we're being policed, - * set lt_bw to the "long-term" average delivery rate from those 2 intervals. +/* See if we have reached next round trip. Upon start of the new round, + * returns packets delivered since previous round start plus this ACK. */ -static void bbr_lt_bw_sampling(struct sock *sk, const struct rate_sample *rs) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - u32 lost, delivered; - u64 bw; - u32 t; - - if (bbr->lt_use_bw) { /* already using long-term rate, lt_bw? */ - if (bbr->mode == BBR_PROBE_BW && bbr->round_start && - ++bbr->lt_rtt_cnt >= bbr_lt_bw_max_rtts) { - bbr_reset_lt_bw_sampling(sk); /* stop using lt_bw */ - bbr_reset_probe_bw_mode(sk); /* restart gain cycling */ - } - return; - } - - /* Wait for the first loss before sampling, to let the policer exhaust - * its tokens and estimate the steady-state rate allowed by the policer. - * Starting samples earlier includes bursts that over-estimate the bw. - */ - if (!bbr->lt_is_sampling) { - if (!rs->losses) - return; - bbr_reset_lt_bw_sampling_interval(sk); - bbr->lt_is_sampling = true; - } - - /* To avoid underestimates, reset sampling if we run out of data. */ - if (rs->is_app_limited) { - bbr_reset_lt_bw_sampling(sk); - return; - } - - if (bbr->round_start) - bbr->lt_rtt_cnt++; /* count round trips in this interval */ - if (bbr->lt_rtt_cnt < bbr_lt_intvl_min_rtts) - return; /* sampling interval needs to be longer */ - if (bbr->lt_rtt_cnt > 4 * bbr_lt_intvl_min_rtts) { - bbr_reset_lt_bw_sampling(sk); /* interval is too long */ - return; - } - - /* End sampling interval when a packet is lost, so we estimate the - * policer tokens were exhausted. Stopping the sampling before the - * tokens are exhausted under-estimates the policed rate. - */ - if (!rs->losses) - return; - - /* Calculate packets lost and delivered in sampling interval. */ - lost = tp->lost - bbr->lt_last_lost; - delivered = tp->delivered - bbr->lt_last_delivered; - /* Is loss rate (lost/delivered) >= lt_loss_thresh? If not, wait. */ - if (!delivered || (lost << BBR_SCALE) < bbr_lt_loss_thresh * delivered) - return; - - /* Find average delivery rate in this sampling interval. */ - t = div_u64(tp->delivered_mstamp, USEC_PER_MSEC) - bbr->lt_last_stamp; - if ((s32)t < 1) - return; /* interval is less than one ms, so wait */ - /* Check if can multiply without overflow */ - if (t >= ~0U / USEC_PER_MSEC) { - bbr_reset_lt_bw_sampling(sk); /* interval too long; reset */ - return; - } - t *= USEC_PER_MSEC; - bw = (u64)delivered * BW_UNIT; - do_div(bw, t); - bbr_lt_bw_interval_done(sk, bw); -} - -/* Estimate the bandwidth based on how fast packets are delivered */ -static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) +static u32 bbr_update_round_start(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u64 bw; + u32 round_delivered = 0; bbr->round_start = 0; - if (rs->delivered < 0 || rs->interval_us <= 0) - return; /* Not a valid observation */ /* See if we've reached the next RTT */ - if (!before(rs->prior_delivered, bbr->next_rtt_delivered)) { + if (rs->interval_us > 0 && + !before(rs->prior_delivered, bbr->next_rtt_delivered)) { + round_delivered = tp->delivered - bbr->next_rtt_delivered; bbr->next_rtt_delivered = tp->delivered; - bbr->rtt_cnt++; bbr->round_start = 1; - bbr->packet_conservation = 0; } + return round_delivered; +} - bbr_lt_bw_sampling(sk, rs); +/* Calculate the bandwidth based on how fast packets are delivered */ +static void bbr_calculate_bw_sample(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + u64 bw = 0; /* Divide delivered by the interval to find a (lower bound) bottleneck * bandwidth sample. Delivered is in packets and interval_us in uS and * ratio will be <<1 for most connections. So delivered is first scaled. + * Round up to allow growth at low rates, even with integer division. */ - bw = div64_long((u64)rs->delivered * BW_UNIT, rs->interval_us); - - /* If this sample is application-limited, it is likely to have a very - * low delivered count that represents application behavior rather than - * the available network rate. Such a sample could drag down estimated - * bw, causing needless slow-down. Thus, to continue to send at the - * last measured network rate, we filter out app-limited samples unless - * they describe the path bw at least as well as our bw model. - * - * So the goal during app-limited phase is to proceed with the best - * network rate no matter how long. We automatically leave this - * phase when app writes faster than the network can deliver :) - */ - if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) { - /* Incorporate new sample into our max bw filter. */ - minmax_running_max(&bbr->bw, bbr_bw_rtts, bbr->rtt_cnt, bw); + if (rs->interval_us > 0) { + if (WARN_ONCE(rs->delivered < 0, + "negative delivered: %d interval_us: %ld\n", + rs->delivered, rs->interval_us)) + return; + + bw = DIV_ROUND_UP_ULL((u64)rs->delivered * BW_UNIT, rs->interval_us); } + + ctx->sample_bw = bw; } /* Estimates the windowed max degree of ack aggregation. @@ -812,7 +795,7 @@ static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) * * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). * Max filter is an approximate sliding window of 5-10 (packet timed) round - * trips. + * trips for non-startup phase, and 1-2 round trips for startup. */ static void bbr_update_ack_aggregation(struct sock *sk, const struct rate_sample *rs) @@ -820,15 +803,19 @@ static void bbr_update_ack_aggregation(struct sock *sk, u32 epoch_us, expected_acked, extra_acked; struct bbr *bbr = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); + u32 extra_acked_win_rtts_thresh = bbr_param(sk, extra_acked_win_rtts); - if (!bbr_extra_acked_gain || rs->acked_sacked <= 0 || + if (!bbr_param(sk, extra_acked_gain) || rs->acked_sacked <= 0 || rs->delivered < 0 || rs->interval_us <= 0) return; if (bbr->round_start) { bbr->extra_acked_win_rtts = min(0x1F, bbr->extra_acked_win_rtts + 1); - if (bbr->extra_acked_win_rtts >= bbr_extra_acked_win_rtts) { + if (!bbr_full_bw_reached(sk)) + extra_acked_win_rtts_thresh = 1; + if (bbr->extra_acked_win_rtts >= + extra_acked_win_rtts_thresh) { bbr->extra_acked_win_rtts = 0; bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? 0 : 1; @@ -862,49 +849,6 @@ static void bbr_update_ack_aggregation(struct sock *sk, bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; } -/* Estimate when the pipe is full, using the change in delivery rate: BBR - * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by - * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited - * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the - * higher rwin, 3: we get higher delivery rate samples. Or transient - * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar - * design goal, but uses delay and inter-ACK spacing instead of bandwidth. - */ -static void bbr_check_full_bw_reached(struct sock *sk, - const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - u32 bw_thresh; - - if (bbr_full_bw_reached(sk) || !bbr->round_start || rs->is_app_limited) - return; - - bw_thresh = (u64)bbr->full_bw * bbr_full_bw_thresh >> BBR_SCALE; - if (bbr_max_bw(sk) >= bw_thresh) { - bbr->full_bw = bbr_max_bw(sk); - bbr->full_bw_cnt = 0; - return; - } - ++bbr->full_bw_cnt; - bbr->full_bw_reached = bbr->full_bw_cnt >= bbr_full_bw_cnt; -} - -/* If pipe is probably full, drain the queue and then enter steady-state. */ -static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - - if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { - bbr->mode = BBR_DRAIN; /* drain queue we created */ - tcp_sk(sk)->snd_ssthresh = - bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); - } /* fall through to check if in-flight is already small: */ - if (bbr->mode == BBR_DRAIN && - bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= - bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) - bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ -} - static void bbr_check_probe_rtt_done(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); @@ -914,9 +858,9 @@ static void bbr_check_probe_rtt_done(struct sock *sk) after(tcp_jiffies32, bbr->probe_rtt_done_stamp))) return; - bbr->min_rtt_stamp = tcp_jiffies32; /* wait a while until PROBE_RTT */ + bbr->probe_rtt_min_stamp = tcp_jiffies32; /* schedule next PROBE_RTT */ tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); - bbr_reset_mode(sk); + bbr_exit_probe_rtt(sk); } /* The goal of PROBE_RTT mode is to have BBR flows cooperatively and @@ -942,23 +886,35 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - bool filter_expired; + bool probe_rtt_expired, min_rtt_expired; + u32 expire; - /* Track min RTT seen in the min_rtt_win_sec filter window: */ - filter_expired = after(tcp_jiffies32, - bbr->min_rtt_stamp + bbr_min_rtt_win_sec * HZ); + /* Track min RTT in probe_rtt_win_ms to time next PROBE_RTT state. */ + expire = bbr->probe_rtt_min_stamp + + msecs_to_jiffies(bbr_param(sk, probe_rtt_win_ms)); + probe_rtt_expired = after(tcp_jiffies32, expire); if (rs->rtt_us >= 0 && - (rs->rtt_us < bbr->min_rtt_us || - (filter_expired && !rs->is_ack_delayed))) { - bbr->min_rtt_us = rs->rtt_us; - bbr->min_rtt_stamp = tcp_jiffies32; + (rs->rtt_us < bbr->probe_rtt_min_us || + (probe_rtt_expired && !rs->is_ack_delayed))) { + bbr->probe_rtt_min_us = rs->rtt_us; + bbr->probe_rtt_min_stamp = tcp_jiffies32; + } + /* Track min RTT seen in the min_rtt_win_sec filter window: */ + expire = bbr->min_rtt_stamp + bbr_param(sk, min_rtt_win_sec) * HZ; + min_rtt_expired = after(tcp_jiffies32, expire); + if (bbr->probe_rtt_min_us <= bbr->min_rtt_us || + min_rtt_expired) { + bbr->min_rtt_us = bbr->probe_rtt_min_us; + bbr->min_rtt_stamp = bbr->probe_rtt_min_stamp; } - if (bbr_probe_rtt_mode_ms > 0 && filter_expired && + if (bbr_param(sk, probe_rtt_mode_ms) > 0 && probe_rtt_expired && !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ bbr_save_cwnd(sk); /* note cwnd so we can restore it */ bbr->probe_rtt_done_stamp = 0; + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; } if (bbr->mode == BBR_PROBE_RTT) { @@ -967,9 +923,9 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; /* Maintain min packets in flight for max(200 ms, 1 round). */ if (!bbr->probe_rtt_done_stamp && - tcp_packets_in_flight(tp) <= bbr_cwnd_min_target) { + tcp_packets_in_flight(tp) <= bbr_probe_rtt_cwnd(sk)) { bbr->probe_rtt_done_stamp = tcp_jiffies32 + - msecs_to_jiffies(bbr_probe_rtt_mode_ms); + msecs_to_jiffies(bbr_param(sk, probe_rtt_mode_ms)); bbr->probe_rtt_round_done = 0; bbr->next_rtt_delivered = tp->delivered; } else if (bbr->probe_rtt_done_stamp) { @@ -990,18 +946,20 @@ static void bbr_update_gains(struct sock *sk) switch (bbr->mode) { case BBR_STARTUP: - bbr->pacing_gain = bbr_high_gain; - bbr->cwnd_gain = bbr_high_gain; + bbr->pacing_gain = bbr_param(sk, startup_pacing_gain); + bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); break; case BBR_DRAIN: - bbr->pacing_gain = bbr_drain_gain; /* slow, to drain */ - bbr->cwnd_gain = bbr_high_gain; /* keep cwnd */ + bbr->pacing_gain = bbr_param(sk, drain_gain); /* slow, to drain */ + bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); /* keep cwnd */ break; case BBR_PROBE_BW: - bbr->pacing_gain = (bbr->lt_use_bw ? - BBR_UNIT : - bbr_pacing_gain[bbr->cycle_idx]); - bbr->cwnd_gain = bbr_cwnd_gain; + bbr->pacing_gain = bbr_pacing_gain[bbr->cycle_idx]; + bbr->cwnd_gain = bbr_param(sk, cwnd_gain); + if (bbr_param(sk, bw_probe_cwnd_gain) && + bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr->cwnd_gain += + BBR_UNIT * bbr_param(sk, bw_probe_cwnd_gain) / 4; break; case BBR_PROBE_RTT: bbr->pacing_gain = BBR_UNIT; @@ -1013,144 +971,1387 @@ static void bbr_update_gains(struct sock *sk) } } -static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) +__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) { - bbr_update_bw(sk, rs); - bbr_update_ack_aggregation(sk, rs); - bbr_update_cycle_phase(sk, rs); - bbr_check_full_bw_reached(sk, rs); - bbr_check_drain(sk, rs); - bbr_update_min_rtt(sk, rs); - bbr_update_gains(sk); + /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ + return 3; } -__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) +/* Incorporate a new bw sample into the current window of our max filter. */ +static void bbr_take_max_bw_sample(struct sock *sk, u32 bw) { struct bbr *bbr = inet_csk_ca(sk); - u32 bw; - bbr_update_model(sk, rs); - - bw = bbr_bw(sk); - bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); - bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain); + bbr->bw_hi[1] = max(bw, bbr->bw_hi[1]); } -__bpf_kfunc static void bbr_init(struct sock *sk) +/* Keep max of last 1-2 cycles. Each PROBE_BW cycle, flip filter window. */ +static void bbr_advance_max_bw_filter(struct sock *sk) { - struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - bbr->prior_cwnd = 0; - tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; - bbr->rtt_cnt = 0; - bbr->next_rtt_delivered = tp->delivered; - bbr->prev_ca_state = TCP_CA_Open; - bbr->packet_conservation = 0; - - bbr->probe_rtt_done_stamp = 0; - bbr->probe_rtt_round_done = 0; - bbr->min_rtt_us = tcp_min_rtt(tp); - bbr->min_rtt_stamp = tcp_jiffies32; - - minmax_reset(&bbr->bw, bbr->rtt_cnt, 0); /* init max bw to 0 */ + if (!bbr->bw_hi[1]) + return; /* no samples in this window; remember old window */ + bbr->bw_hi[0] = bbr->bw_hi[1]; + bbr->bw_hi[1] = 0; +} - bbr->has_seen_rtt = 0; - bbr_init_pacing_rate_from_rtt(sk); +/* Reset the estimator for reaching full bandwidth based on bw plateau. */ +static void bbr_reset_full_bw(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); - bbr->round_start = 0; - bbr->idle_restart = 0; - bbr->full_bw_reached = 0; bbr->full_bw = 0; bbr->full_bw_cnt = 0; - bbr->cycle_mstamp = 0; - bbr->cycle_idx = 0; - bbr_reset_lt_bw_sampling(sk); - bbr_reset_startup_mode(sk); + bbr->full_bw_now = 0; +} - bbr->ack_epoch_mstamp = tp->tcp_mstamp; - bbr->ack_epoch_acked = 0; - bbr->extra_acked_win_rtts = 0; - bbr->extra_acked_win_idx = 0; - bbr->extra_acked[0] = 0; - bbr->extra_acked[1] = 0; +/* How much do we want in flight? Our BDP, unless congestion cut cwnd. */ +static u32 bbr_target_inflight(struct sock *sk) +{ + u32 bdp = bbr_inflight(sk, bbr_bw(sk), BBR_UNIT); - cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); + return min(bdp, tcp_sk(sk)->snd_cwnd); } -__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) +static bool bbr_is_probing_bandwidth(struct sock *sk) { - /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ - return 3; + struct bbr *bbr = inet_csk_ca(sk); + + return (bbr->mode == BBR_STARTUP) || + (bbr->mode == BBR_PROBE_BW && + (bbr->cycle_idx == BBR_BW_PROBE_REFILL || + bbr->cycle_idx == BBR_BW_PROBE_UP)); +} + +/* Has the given amount of time elapsed since we marked the phase start? */ +static bool bbr_has_elapsed_in_phase(const struct sock *sk, u32 interval_us) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct bbr *bbr = inet_csk_ca(sk); + + return tcp_stamp_us_delta(tp->tcp_mstamp, + bbr->cycle_mstamp + interval_us) > 0; +} + +static void bbr_handle_queue_too_high_in_startup(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bdp; /* estimated BDP in packets, with quantization budget */ + + bbr->full_bw_reached = 1; + + bdp = bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + bbr->inflight_hi = max(bdp, bbr->inflight_latest); +} + +/* Exit STARTUP upon N consecutive rounds with ECN mark rate > ecn_thresh. */ +static void bbr_check_ecn_too_high_in_startup(struct sock *sk, u32 ce_ratio) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk) || !bbr->ecn_eligible || + !bbr_param(sk, full_ecn_cnt) || !bbr_param(sk, ecn_thresh)) + return; + + if (ce_ratio >= bbr_param(sk, ecn_thresh)) + bbr->startup_ecn_rounds++; + else + bbr->startup_ecn_rounds = 0; + + if (bbr->startup_ecn_rounds >= bbr_param(sk, full_ecn_cnt)) { + bbr_handle_queue_too_high_in_startup(sk); + return; + } +} + +/* Updates ecn_alpha and returns ce_ratio. -1 if not available. */ +static int bbr_update_ecn_alpha(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct bbr *bbr = inet_csk_ca(sk); + s32 delivered, delivered_ce; + u64 alpha, ce_ratio; + u32 gain; + bool want_ecn_alpha; + + /* See if we should use ECN sender logic for this connection. */ + if (!bbr->ecn_eligible && bbr_can_use_ecn(sk) && + !!bbr_param(sk, ecn_factor) && + (bbr->min_rtt_us <= bbr_ecn_max_rtt_us || + !bbr_ecn_max_rtt_us)) + bbr->ecn_eligible = 1; + + /* Skip updating alpha only if not ECN-eligible and PLB is disabled. */ + want_ecn_alpha = (bbr->ecn_eligible || + (bbr_can_use_ecn(sk) && + READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled))); + if (!want_ecn_alpha) + return -1; + + delivered = tp->delivered - bbr->alpha_last_delivered; + delivered_ce = tp->delivered_ce - bbr->alpha_last_delivered_ce; + + if (delivered == 0 || /* avoid divide by zero */ + WARN_ON_ONCE(delivered < 0 || delivered_ce < 0)) /* backwards? */ + return -1; + + BUILD_BUG_ON(BBR_SCALE != TCP_PLB_SCALE); + ce_ratio = (u64)delivered_ce << BBR_SCALE; + do_div(ce_ratio, delivered); + + gain = bbr_param(sk, ecn_alpha_gain); + alpha = ((BBR_UNIT - gain) * bbr->ecn_alpha) >> BBR_SCALE; + alpha += (gain * ce_ratio) >> BBR_SCALE; + bbr->ecn_alpha = min_t(u32, alpha, BBR_UNIT); + + bbr->alpha_last_delivered = tp->delivered; + bbr->alpha_last_delivered_ce = tp->delivered_ce; + + bbr_check_ecn_too_high_in_startup(sk, ce_ratio); + return (int)ce_ratio; } -/* In theory BBR does not need to undo the cwnd since it does not - * always reduce cwnd on losses (see bbr_main()). Keep it for now. +/* Protective Load Balancing (PLB). PLB rehashes outgoing data (to a new IPv6 + * flow label) if it encounters sustained congestion in the form of ECN marks. */ -__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) +static void bbr_plb(struct sock *sk, const struct rate_sample *rs, int ce_ratio) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->round_start && ce_ratio >= 0) + tcp_plb_update_state(sk, &bbr->plb, ce_ratio); + + tcp_plb_check_rehash(sk, &bbr->plb); +} + +/* Each round trip of BBR_BW_PROBE_UP, double volume of probing data. */ +static void bbr_raise_inflight_hi_slope(struct sock *sk) { + struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); + u32 growth_this_round, cnt; + + /* Calculate "slope": packets S/Acked per inflight_hi increment. */ + growth_this_round = 1 << bbr->bw_probe_up_rounds; + bbr->bw_probe_up_rounds = min(bbr->bw_probe_up_rounds + 1, 30); + cnt = tcp_snd_cwnd(tp) / growth_this_round; + cnt = max(cnt, 1U); + bbr->bw_probe_up_cnt = cnt; +} + +/* In BBR_BW_PROBE_UP, not seeing high loss/ECN/queue, so raise inflight_hi. */ +static void bbr_probe_inflight_hi_upward(struct sock *sk, + const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 delta; + + if (!tp->is_cwnd_limited || tcp_snd_cwnd(tp) < bbr->inflight_hi) + return; /* not fully using inflight_hi, so don't grow it */ + + /* For each bw_probe_up_cnt packets ACKed, increase inflight_hi by 1. */ + bbr->bw_probe_up_acks += rs->acked_sacked; + if (bbr->bw_probe_up_acks >= bbr->bw_probe_up_cnt) { + delta = bbr->bw_probe_up_acks / bbr->bw_probe_up_cnt; + bbr->bw_probe_up_acks -= delta * bbr->bw_probe_up_cnt; + bbr->inflight_hi += delta; + bbr->try_fast_path = 0; /* Need to update cwnd */ + } + + if (bbr->round_start) + bbr_raise_inflight_hi_slope(sk); +} + +/* Does loss/ECN rate for this sample say inflight is "too high"? + * This is used by both the bbr_check_loss_too_high_in_startup() function, + * and in PROBE_UP. + */ +static bool bbr_is_inflight_too_high(const struct sock *sk, + const struct rate_sample *rs) +{ + const struct bbr *bbr = inet_csk_ca(sk); + u32 loss_thresh, ecn_thresh; - bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ + if (rs->lost > 0 && rs->tx_in_flight) { + loss_thresh = (u64)rs->tx_in_flight * bbr_param(sk, loss_thresh) >> + BBR_SCALE; + if (rs->lost > loss_thresh) { + return true; + } + } + + if (rs->delivered_ce > 0 && rs->delivered > 0 && + bbr->ecn_eligible && !!bbr_param(sk, ecn_thresh)) { + ecn_thresh = (u64)rs->delivered * bbr_param(sk, ecn_thresh) >> + BBR_SCALE; + if (rs->delivered_ce > ecn_thresh) { + return true; + } + } + + return false; +} + +/* Calculate the tx_in_flight level that corresponded to excessive loss. + * We find "lost_prefix" segs of the skb where loss rate went too high, + * by solving for "lost_prefix" in the following equation: + * lost / inflight >= loss_thresh + * (lost_prev + lost_prefix) / (inflight_prev + lost_prefix) >= loss_thresh + * Then we take that equation, convert it to fixed point, and + * round up to the nearest packet. + */ +static u32 bbr_inflight_hi_from_lost_skb(const struct sock *sk, + const struct rate_sample *rs, + const struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u32 loss_thresh = bbr_param(sk, loss_thresh); + u32 pcount, divisor, inflight_hi; + s32 inflight_prev, lost_prev; + u64 loss_budget, lost_prefix; + + pcount = tcp_skb_pcount(skb); + + /* How much data was in flight before this skb? */ + inflight_prev = rs->tx_in_flight - pcount; + if (inflight_prev < 0) { + WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( + pcount, + TCP_SKB_CB(skb)->sacked, + rs->tx_in_flight), + "tx_in_flight: %u pcount: %u reneg: %u", + rs->tx_in_flight, pcount, tcp_sk(sk)->is_sack_reneg); + return ~0U; + } + + /* How much inflight data was marked lost before this skb? */ + lost_prev = rs->lost - pcount; + if (WARN_ONCE(lost_prev < 0, + "cwnd: %u ca: %d out: %u lost: %u pif: %u " + "tx_in_flight: %u tx.lost: %u tp->lost: %u rs->lost: %d " + "lost_prev: %d pcount: %d seq: %u end_seq: %u reneg: %u", + tcp_snd_cwnd(tp), inet_csk(sk)->icsk_ca_state, + tp->packets_out, tp->lost_out, tcp_packets_in_flight(tp), + rs->tx_in_flight, TCP_SKB_CB(skb)->tx.lost, tp->lost, + rs->lost, lost_prev, pcount, + TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, + tp->is_sack_reneg)) + return ~0U; + + /* At what prefix of this lost skb did losss rate exceed loss_thresh? */ + loss_budget = (u64)inflight_prev * loss_thresh + BBR_UNIT - 1; + loss_budget >>= BBR_SCALE; + if (lost_prev >= loss_budget) { + lost_prefix = 0; /* previous losses crossed loss_thresh */ + } else { + lost_prefix = loss_budget - lost_prev; + lost_prefix <<= BBR_SCALE; + divisor = BBR_UNIT - loss_thresh; + if (WARN_ON_ONCE(!divisor)) /* loss_thresh is 8 bits */ + return ~0U; + do_div(lost_prefix, divisor); + } + + inflight_hi = inflight_prev + lost_prefix; + return inflight_hi; +} + +/* If loss/ECN rates during probing indicated we may have overfilled a + * buffer, return an operating point that tries to leave unutilized headroom in + * the path for other flows, for fairness convergence and lower RTTs and loss. + */ +static u32 bbr_inflight_with_headroom(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 headroom, headroom_fraction; + + if (bbr->inflight_hi == ~0U) + return ~0U; + + headroom_fraction = bbr_param(sk, inflight_headroom); + headroom = ((u64)bbr->inflight_hi * headroom_fraction) >> BBR_SCALE; + headroom = max(headroom, 1U); + return max_t(s32, bbr->inflight_hi - headroom, + bbr_param(sk, cwnd_min_target)); +} + +/* Bound cwnd to a sensible level, based on our current probing state + * machine phase and model of a good inflight level (inflight_lo, inflight_hi). + */ +static void bbr_bound_cwnd_for_inflight_model(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 cap; + + /* tcp_rcv_synsent_state_process() currently calls tcp_ack() + * and thus cong_control() without first initializing us(!). + */ + if (!bbr->initialized) + return; + + cap = ~0U; + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx != BBR_BW_PROBE_CRUISE) { + /* Probe to see if more packets fit in the path. */ + cap = bbr->inflight_hi; + } else { + if (bbr->mode == BBR_PROBE_RTT || + (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_CRUISE)) + cap = bbr_inflight_with_headroom(sk); + } + /* Adapt to any loss/ECN since our last bw probe. */ + cap = min(cap, bbr->inflight_lo); + + cap = max_t(u32, cap, bbr_param(sk, cwnd_min_target)); + tcp_snd_cwnd_set(tp, min(cap, tcp_snd_cwnd(tp))); +} + +/* How should we multiplicatively cut bw or inflight limits based on ECN? */ +static u32 bbr_ecn_cut(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return BBR_UNIT - + ((bbr->ecn_alpha * bbr_param(sk, ecn_factor)) >> BBR_SCALE); +} + +/* Init lower bounds if have not inited yet. */ +static void bbr_init_lower_bounds(struct sock *sk, bool init_bw) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (init_bw && bbr->bw_lo == ~0U) + bbr->bw_lo = bbr_max_bw(sk); + if (bbr->inflight_lo == ~0U) + bbr->inflight_lo = tcp_snd_cwnd(tp); +} + +/* Reduce bw and inflight to (1 - beta). */ +static void bbr_loss_lower_bounds(struct sock *sk, u32 *bw, u32 *inflight) +{ + struct bbr* bbr = inet_csk_ca(sk); + u32 loss_cut = BBR_UNIT - bbr_param(sk, beta); + + *bw = max_t(u32, bbr->bw_latest, + (u64)bbr->bw_lo * loss_cut >> BBR_SCALE); + *inflight = max_t(u32, bbr->inflight_latest, + (u64)bbr->inflight_lo * loss_cut >> BBR_SCALE); +} + +/* Reduce inflight to (1 - alpha*ecn_factor). */ +static void bbr_ecn_lower_bounds(struct sock *sk, u32 *inflight) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 ecn_cut = bbr_ecn_cut(sk); + + *inflight = (u64)bbr->inflight_lo * ecn_cut >> BBR_SCALE; +} + +/* Estimate a short-term lower bound on the capacity available now, based + * on measurements of the current delivery process and recent history. When we + * are seeing loss/ECN at times when we are not probing bw, then conservatively + * move toward flow balance by multiplicatively cutting our short-term + * estimated safe rate and volume of data (bw_lo and inflight_lo). We use a + * multiplicative decrease in order to converge to a lower capacity in time + * logarithmic in the magnitude of the decrease. + * + * However, we do not cut our short-term estimates lower than the current rate + * and volume of delivered data from this round trip, since from the current + * delivery process we can estimate the measured capacity available now. + * + * Anything faster than that approach would knowingly risk high loss, which can + * cause low bw for Reno/CUBIC and high loss recovery latency for + * request/response flows using any congestion control. + */ +static void bbr_adapt_lower_bounds(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 ecn_inflight_lo = ~0U; + + /* We only use lower-bound estimates when not probing bw. + * When probing we need to push inflight higher to probe bw. + */ + if (bbr_is_probing_bandwidth(sk)) + return; + + /* ECN response. */ + if (bbr->ecn_in_round && !!bbr_param(sk, ecn_factor)) { + bbr_init_lower_bounds(sk, false); + bbr_ecn_lower_bounds(sk, &ecn_inflight_lo); + } + + /* Loss response. */ + if (bbr->loss_in_round) { + bbr_init_lower_bounds(sk, true); + bbr_loss_lower_bounds(sk, &bbr->bw_lo, &bbr->inflight_lo); + } + + /* Adjust to the lower of the levels implied by loss/ECN. */ + bbr->inflight_lo = min(bbr->inflight_lo, ecn_inflight_lo); + bbr->bw_lo = max(1U, bbr->bw_lo); +} + +/* Reset any short-term lower-bound adaptation to congestion, so that we can + * push our inflight up. + */ +static void bbr_reset_lower_bounds(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->bw_lo = ~0U; + bbr->inflight_lo = ~0U; +} + +/* After bw probing (STARTUP/PROBE_UP), reset signals before entering a state + * machine phase where we adapt our lower bound based on congestion signals. + */ +static void bbr_reset_congestion_signals(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; + bbr->loss_in_cycle = 0; + bbr->ecn_in_cycle = 0; + bbr->bw_latest = 0; + bbr->inflight_latest = 0; +} + +static void bbr_exit_loss_recovery(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); + bbr->try_fast_path = 0; /* bound cwnd using latest model */ +} + +/* Update rate and volume of delivered data from latest round trip. */ +static void bbr_update_latest_delivery_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->loss_round_start = 0; + if (rs->interval_us <= 0 || !rs->acked_sacked) + return; /* Not a valid observation */ + + bbr->bw_latest = max_t(u32, bbr->bw_latest, ctx->sample_bw); + bbr->inflight_latest = max_t(u32, bbr->inflight_latest, rs->delivered); + + if (!before(rs->prior_delivered, bbr->loss_round_delivered)) { + bbr->loss_round_delivered = tp->delivered; + bbr->loss_round_start = 1; /* mark start of new round trip */ + } +} + +/* Once per round, reset filter for latest rate and volume of delivered data. */ +static void bbr_advance_latest_delivery_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* If ACK matches a TLP retransmit, persist the filter. If we detect + * that a TLP retransmit plugged a tail loss, we'll want to remember + * how much data the path delivered before the tail loss. + */ + if (bbr->loss_round_start && !rs->is_acking_tlp_retrans_seq) { + bbr->bw_latest = ctx->sample_bw; + bbr->inflight_latest = rs->delivered; + } +} + +/* Update (most of) our congestion signals: track the recent rate and volume of + * delivered data, presence of loss, and EWMA degree of ECN marking. + */ +static void bbr_update_congestion_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + + if (rs->interval_us <= 0 || !rs->acked_sacked) + return; /* Not a valid observation */ + bw = ctx->sample_bw; + + if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) + bbr_take_max_bw_sample(sk, bw); + + bbr->loss_in_round |= (rs->losses > 0); + + if (!bbr->loss_round_start) + return; /* skip the per-round-trip updates */ + /* Now do per-round-trip updates. */ + bbr_adapt_lower_bounds(sk, rs); + + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; +} + +/* Bandwidth probing can cause loss. To help coexistence with loss-based + * congestion control we spread out our probing in a Reno-conscious way. Due to + * the shape of the Reno sawtooth, the time required between loss epochs for an + * idealized Reno flow is a number of round trips that is the BDP of that + * flow. We count packet-timed round trips directly, since measured RTT can + * vary widely, and Reno is driven by packet-timed round trips. + */ +static bool bbr_is_reno_coexistence_probe_time(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 rounds; + + /* Random loss can shave some small percentage off of our inflight + * in each round. To survive this, flows need robust periodic probes. + */ + rounds = min_t(u32, bbr_param(sk, bw_probe_max_rounds), bbr_target_inflight(sk)); + return bbr->rounds_since_probe >= rounds; +} + +/* How long do we want to wait before probing for bandwidth (and risking + * loss)? We randomize the wait, for better mixing and fairness convergence. + * + * We bound the Reno-coexistence inter-bw-probe time to be 62-63 round trips. + * This is calculated to allow fairness with a 25Mbps, 30ms Reno flow, + * (eg 4K video to a broadband user): + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + * + * We bound the BBR-native inter-bw-probe wall clock time to be: + * (a) higher than 2 sec: to try to avoid causing loss for a long enough time + * to allow Reno at 30ms to get 4K video bw, the inter-bw-probe time must + * be at least: 25Mbps * .030sec / (1514bytes) * 0.030sec = 1.9secs + * (b) lower than 3 sec: to ensure flows can start probing in a reasonable + * amount of time to discover unutilized bw on human-scale interactive + * time-scales (e.g. perhaps traffic from a web page download that we + * were competing with is now complete). + */ +static void bbr_pick_probe_wait(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Decide the random round-trip bound for wait until probe: */ + bbr->rounds_since_probe = + get_random_u32_below(bbr_param(sk, bw_probe_rand_rounds)); + /* Decide the random wall clock bound for wait until probe: */ + bbr->probe_wait_us = bbr_param(sk, bw_probe_base_us) + + get_random_u32_below(bbr_param(sk, bw_probe_rand_us)); +} + +static void bbr_set_cycle_idx(struct sock *sk, int cycle_idx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->cycle_idx = cycle_idx; + /* New phase, so need to update cwnd and pacing rate. */ + bbr->try_fast_path = 0; +} + +/* Send at estimated bw to fill the pipe, but not queue. We need this phase + * before PROBE_UP, because as soon as we send faster than the available bw + * we will start building a queue, and if the buffer is shallow we can cause + * loss. If we do not fill the pipe before we cause this loss, our bw_hi and + * inflight_hi estimates will underestimate. + */ +static void bbr_start_bw_probe_refill(struct sock *sk, u32 bw_probe_up_rounds) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_lower_bounds(sk); + bbr->bw_probe_up_rounds = bw_probe_up_rounds; + bbr->bw_probe_up_acks = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_REFILLING; + bbr->next_rtt_delivered = tp->delivered; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_REFILL); +} + +/* Now probe max deliverable data rate and volume. */ +static void bbr_start_bw_probe_up(struct sock *sk, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->ack_phase = BBR_ACKS_PROBE_STARTING; + bbr->next_rtt_delivered = tp->delivered; + bbr->cycle_mstamp = tp->tcp_mstamp; + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_UP); + bbr_raise_inflight_hi_slope(sk); +} + +/* Start a new PROBE_BW probing cycle of some wall clock length. Pick a wall + * clock time at which to probe beyond an inflight that we think to be + * safe. This will knowingly risk packet loss, so we want to do this rarely, to + * keep packet loss rates low. Also start a round-trip counter, to probe faster + * if we estimate a Reno flow at our BDP would probe faster. + */ +static void bbr_start_bw_probe_down(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_congestion_signals(sk); + bbr->bw_probe_up_cnt = ~0U; /* not growing inflight_hi any more */ + bbr_pick_probe_wait(sk); + bbr->cycle_mstamp = tp->tcp_mstamp; /* start wall clock */ + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_DOWN); +} + +/* Cruise: maintain what we estimate to be a neutral, conservative + * operating point, without attempting to probe up for bandwidth or down for + * RTT, and only reducing inflight in response to loss/ECN signals. + */ +static void bbr_start_bw_probe_cruise(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->inflight_lo != ~0U) + bbr->inflight_lo = min(bbr->inflight_lo, bbr->inflight_hi); + + bbr_set_cycle_idx(sk, BBR_BW_PROBE_CRUISE); +} + +/* Loss and/or ECN rate is too high while probing. + * Adapt (once per bw probe) by cutting inflight_hi and then restarting cycle. + */ +static void bbr_handle_inflight_too_high(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + const u32 beta = bbr_param(sk, beta); + + bbr->prev_probe_too_high = 1; + bbr->bw_probe_samples = 0; /* only react once per probe */ + /* If we are app-limited then we are not robustly + * probing the max volume of inflight data we think + * might be safe (analogous to how app-limited bw + * samples are not known to be robustly probing bw). + */ + if (!rs->is_app_limited) { + bbr->inflight_hi = max_t(u32, rs->tx_in_flight, + (u64)bbr_target_inflight(sk) * + (BBR_UNIT - beta) >> BBR_SCALE); + } + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr_start_bw_probe_down(sk); +} + +/* If we're seeing bw and loss samples reflecting our bw probing, adapt + * using the signals we see. If loss or ECN mark rate gets too high, then adapt + * inflight_hi downward. If we're able to push inflight higher without such + * signals, push higher: adapt inflight_hi upward. + */ +static bool bbr_adapt_upper_bounds(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Track when we'll see bw/loss samples resulting from our bw probes. */ + if (bbr->ack_phase == BBR_ACKS_PROBE_STARTING && bbr->round_start) + bbr->ack_phase = BBR_ACKS_PROBE_FEEDBACK; + if (bbr->ack_phase == BBR_ACKS_PROBE_STOPPING && bbr->round_start) { + /* End of samples from bw probing phase. */ + bbr->bw_probe_samples = 0; + bbr->ack_phase = BBR_ACKS_INIT; + /* At this point in the cycle, our current bw sample is also + * our best recent chance at finding the highest available bw + * for this flow. So now is the best time to forget the bw + * samples from the previous cycle, by advancing the window. + */ + if (bbr->mode == BBR_PROBE_BW && !rs->is_app_limited) + bbr_advance_max_bw_filter(sk); + /* If we had an inflight_hi, then probed and pushed inflight all + * the way up to hit that inflight_hi without seeing any + * high loss/ECN in all the resulting ACKs from that probing, + * then probe up again, this time letting inflight persist at + * inflight_hi for a round trip, then accelerating beyond. + */ + if (bbr->mode == BBR_PROBE_BW && + bbr->stopped_risky_probe && !bbr->prev_probe_too_high) { + bbr_start_bw_probe_refill(sk, 0); + return true; /* yes, decided state transition */ + } + } + if (bbr_is_inflight_too_high(sk, rs)) { + if (bbr->bw_probe_samples) /* sample is from bw probing? */ + bbr_handle_inflight_too_high(sk, rs); + } else { + /* Loss/ECN rate is declared safe. Adjust upper bound upward. */ + + if (bbr->inflight_hi == ~0U) + return false; /* no excess queue signals yet */ + + /* To be resilient to random loss, we must raise bw/inflight_hi + * if we observe in any phase that a higher level is safe. + */ + if (rs->tx_in_flight > bbr->inflight_hi) { + bbr->inflight_hi = rs->tx_in_flight; + } + + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr_probe_inflight_hi_upward(sk, rs); + } + + return false; +} + +/* Check if it's time to probe for bandwidth now, and if so, kick it off. */ +static bool bbr_check_time_to_probe_bw(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 n; + + /* If we seem to be at an operating point where we are not seeing loss + * but we are seeing ECN marks, then when the ECN marks cease we reprobe + * quickly (in case cross-traffic has ceased and freed up bw). + */ + if (bbr_param(sk, ecn_reprobe_gain) && bbr->ecn_eligible && + bbr->ecn_in_cycle && !bbr->loss_in_cycle && + inet_csk(sk)->icsk_ca_state == TCP_CA_Open) { + /* Calculate n so that when bbr_raise_inflight_hi_slope() + * computes growth_this_round as 2^n it will be roughly the + * desired volume of data (inflight_hi*ecn_reprobe_gain). + */ + n = ilog2((((u64)bbr->inflight_hi * + bbr_param(sk, ecn_reprobe_gain)) >> BBR_SCALE)); + bbr_start_bw_probe_refill(sk, n); + return true; + } + + if (bbr_has_elapsed_in_phase(sk, bbr->probe_wait_us) || + bbr_is_reno_coexistence_probe_time(sk)) { + bbr_start_bw_probe_refill(sk, 0); + return true; + } + return false; +} + +/* Is it time to transition from PROBE_DOWN to PROBE_CRUISE? */ +static bool bbr_check_time_to_cruise(struct sock *sk, u32 inflight, u32 bw) +{ + /* Always need to pull inflight down to leave headroom in queue. */ + if (inflight > bbr_inflight_with_headroom(sk)) + return false; + + return inflight <= bbr_inflight(sk, bw, BBR_UNIT); +} + +/* PROBE_BW state machine: cruise, refill, probe for bw, or drain? */ +static void bbr_update_cycle_phase(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + bool is_bw_probe_done = false; + u32 inflight, bw; + + if (!bbr_full_bw_reached(sk)) + return; + + /* In DRAIN, PROBE_BW, or PROBE_RTT, adjust upper bounds. */ + if (bbr_adapt_upper_bounds(sk, rs, ctx)) + return; /* already decided state transition */ + + if (bbr->mode != BBR_PROBE_BW) + return; + + inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); + bw = bbr_max_bw(sk); + + switch (bbr->cycle_idx) { + /* First we spend most of our time cruising with a pacing_gain of 1.0, + * which paces at the estimated bw, to try to fully use the pipe + * without building queue. If we encounter loss/ECN marks, we adapt + * by slowing down. + */ + case BBR_BW_PROBE_CRUISE: + if (bbr_check_time_to_probe_bw(sk, rs)) + return; /* already decided state transition */ + break; + + /* After cruising, when it's time to probe, we first "refill": we send + * at the estimated bw to fill the pipe, before probing higher and + * knowingly risking overflowing the bottleneck buffer (causing loss). + */ + case BBR_BW_PROBE_REFILL: + if (bbr->round_start) { + /* After one full round trip of sending in REFILL, we + * start to see bw samples reflecting our REFILL, which + * may be putting too much data in flight. + */ + bbr->bw_probe_samples = 1; + bbr_start_bw_probe_up(sk, ctx); + } + break; + + /* After we refill the pipe, we probe by using a pacing_gain > 1.0, to + * probe for bw. If we have not seen loss/ECN, we try to raise inflight + * to at least pacing_gain*BDP; note that this may take more than + * min_rtt if min_rtt is small (e.g. on a LAN). + * + * We terminate PROBE_UP bandwidth probing upon any of the following: + * + * (1) We've pushed inflight up to hit the inflight_hi target set in the + * most recent previous bw probe phase. Thus we want to start + * draining the queue immediately because it's very likely the most + * recently sent packets will fill the queue and cause drops. + * (2) If inflight_hi has not limited bandwidth growth recently, and + * yet delivered bandwidth has not increased much recently + * (bbr->full_bw_now). + * (3) Loss filter says loss rate is "too high". + * (4) ECN filter says ECN mark rate is "too high". + * + * (1) (2) checked here, (3) (4) checked in bbr_is_inflight_too_high() + */ + case BBR_BW_PROBE_UP: + if (bbr->prev_probe_too_high && + inflight >= bbr->inflight_hi) { + bbr->stopped_risky_probe = 1; + is_bw_probe_done = true; + } else { + if (tp->is_cwnd_limited && + tcp_snd_cwnd(tp) >= bbr->inflight_hi) { + /* inflight_hi is limiting bw growth */ + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + } else if (bbr->full_bw_now) { + /* Plateau in estimated bw. Pipe looks full. */ + is_bw_probe_done = true; + } + } + if (is_bw_probe_done) { + bbr->prev_probe_too_high = 0; /* no loss/ECN (yet) */ + bbr_start_bw_probe_down(sk); /* restart w/ down */ + } + break; + + /* After probing in PROBE_UP, we have usually accumulated some data in + * the bottleneck buffer (if bw probing didn't find more bw). We next + * enter PROBE_DOWN to try to drain any excess data from the queue. To + * do this, we use a pacing_gain < 1.0. We hold this pacing gain until + * our inflight is less then that target cruising point, which is the + * minimum of (a) the amount needed to leave headroom, and (b) the + * estimated BDP. Once inflight falls to match the target, we estimate + * the queue is drained; persisting would underutilize the pipe. + */ + case BBR_BW_PROBE_DOWN: + if (bbr_check_time_to_probe_bw(sk, rs)) + return; /* already decided state transition */ + if (bbr_check_time_to_cruise(sk, inflight, bw)) + bbr_start_bw_probe_cruise(sk); + break; + + default: + WARN_ONCE(1, "BBR invalid cycle index %u\n", bbr->cycle_idx); + } +} + +/* Exiting PROBE_RTT, so return to bandwidth probing in STARTUP or PROBE_BW. */ +static void bbr_exit_probe_rtt(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_lower_bounds(sk); + if (bbr_full_bw_reached(sk)) { + bbr->mode = BBR_PROBE_BW; + /* Raising inflight after PROBE_RTT may cause loss, so reset + * the PROBE_BW clock and schedule the next bandwidth probe for + * a friendly and randomized future point in time. + */ + bbr_start_bw_probe_down(sk); + /* Since we are exiting PROBE_RTT, we know inflight is + * below our estimated BDP, so it is reasonable to cruise. + */ + bbr_start_bw_probe_cruise(sk); + } else { + bbr->mode = BBR_STARTUP; + } +} + +/* Exit STARTUP based on loss rate > 1% and loss gaps in round >= N. Wait until + * the end of the round in recovery to get a good estimate of how many packets + * have been lost, and how many we need to drain with a low pacing rate. + */ +static void bbr_check_loss_too_high_in_startup(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk)) + return; + + /* For STARTUP exit, check the loss rate at the end of each round trip + * of Recovery episodes in STARTUP. We check the loss rate at the end + * of the round trip to filter out noisy/low loss and have a better + * sense of inflight (extent of loss), so we can drain more accurately. + */ + if (rs->losses && bbr->loss_events_in_round < 0xf) + bbr->loss_events_in_round++; /* update saturating counter */ + if (bbr_param(sk, full_loss_cnt) && bbr->loss_round_start && + inet_csk(sk)->icsk_ca_state == TCP_CA_Recovery && + bbr->loss_events_in_round >= bbr_param(sk, full_loss_cnt) && + bbr_is_inflight_too_high(sk, rs)) { + bbr_handle_queue_too_high_in_startup(sk); + return; + } + if (bbr->loss_round_start) + bbr->loss_events_in_round = 0; +} + +/* Estimate when the pipe is full, using the change in delivery rate: BBR + * estimates bw probing filled the pipe if the estimated bw hasn't changed by + * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited + * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the + * higher rwin, 3: we get higher delivery rate samples. Or transient + * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar + * design goal, but uses delay and inter-ACK spacing instead of bandwidth. + */ +static void bbr_check_full_bw_reached(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bw_thresh, full_cnt, thresh; + + if (bbr->full_bw_now || rs->is_app_limited) + return; + + thresh = bbr_param(sk, full_bw_thresh); + full_cnt = bbr_param(sk, full_bw_cnt); + bw_thresh = (u64)bbr->full_bw * thresh >> BBR_SCALE; + if (ctx->sample_bw >= bw_thresh) { + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + return; + } + if (!bbr->round_start) + return; + ++bbr->full_bw_cnt; + bbr->full_bw_now = bbr->full_bw_cnt >= full_cnt; + bbr->full_bw_reached |= bbr->full_bw_now; +} + +/* If pipe is probably full, drain the queue and then enter steady-state. */ +static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { + bbr->mode = BBR_DRAIN; /* drain queue we created */ + /* Set ssthresh to export purely for monitoring, to signal + * completion of initial STARTUP by setting to a non- + * TCP_INFINITE_SSTHRESH value (ssthresh is not used by BBR). + */ + tcp_sk(sk)->snd_ssthresh = + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + bbr_reset_congestion_signals(sk); + } /* fall through to check if in-flight is already small: */ + if (bbr->mode == BBR_DRAIN && + bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) { + bbr->mode = BBR_PROBE_BW; + bbr_start_bw_probe_down(sk); + } +} + +static void bbr_update_model(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + bbr_update_congestion_signals(sk, rs, ctx); + bbr_update_ack_aggregation(sk, rs); + bbr_check_loss_too_high_in_startup(sk, rs); + bbr_check_full_bw_reached(sk, rs, ctx); + bbr_check_drain(sk, rs, ctx); + bbr_update_cycle_phase(sk, rs, ctx); + bbr_update_min_rtt(sk, rs); +} + +/* Fast path for app-limited case. + * + * On each ack, we execute bbr state machine, which primarily consists of: + * 1) update model based on new rate sample, and + * 2) update control based on updated model or state change. + * + * There are certain workload/scenarios, e.g. app-limited case, where + * either we can skip updating model or we can skip update of both model + * as well as control. This provides signifcant softirq cpu savings for + * processing incoming acks. + * + * In case of app-limited, if there is no congestion (loss/ecn) and + * if observed bw sample is less than current estimated bw, then we can + * skip some of the computation in bbr state processing: + * + * - if there is no rtt/mode/phase change: In this case, since all the + * parameters of the network model are constant, we can skip model + * as well control update. + * + * - else we can skip rest of the model update. But we still need to + * update the control to account for the new rtt/mode/phase. + * + * Returns whether we can take fast path or not. + */ +static bool bbr_run_fast_path(struct sock *sk, bool *update_model, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 prev_min_rtt_us, prev_mode; + + if (bbr_param(sk, fast_path) && bbr->try_fast_path && + rs->is_app_limited && ctx->sample_bw < bbr_max_bw(sk) && + !bbr->loss_in_round && !bbr->ecn_in_round ) { + prev_mode = bbr->mode; + prev_min_rtt_us = bbr->min_rtt_us; + bbr_check_drain(sk, rs, ctx); + bbr_update_cycle_phase(sk, rs, ctx); + bbr_update_min_rtt(sk, rs); + + if (bbr->mode == prev_mode && + bbr->min_rtt_us == prev_min_rtt_us && + bbr->try_fast_path) { + return true; + } + + /* Skip model update, but control still needs to be updated */ + *update_model = false; + } + return false; +} + +__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, + const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct bbr_context ctx = { 0 }; + bool update_model = true; + u32 bw, round_delivered; + int ce_ratio = -1; + + round_delivered = bbr_update_round_start(sk, rs, &ctx); + if (bbr->round_start) { + bbr->rounds_since_probe = + min_t(s32, bbr->rounds_since_probe + 1, 0xFF); + ce_ratio = bbr_update_ecn_alpha(sk); + } + bbr_plb(sk, rs, ce_ratio); + + bbr->ecn_in_round |= (bbr->ecn_eligible && rs->is_ece); + bbr_calculate_bw_sample(sk, rs, &ctx); + bbr_update_latest_delivery_signals(sk, rs, &ctx); + + if (bbr_run_fast_path(sk, &update_model, rs, &ctx)) + goto out; + + if (update_model) + bbr_update_model(sk, rs, &ctx); + + bbr_update_gains(sk); + bw = bbr_bw(sk); + bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); + bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain, + tcp_snd_cwnd(tp), &ctx); + bbr_bound_cwnd_for_inflight_model(sk); + +out: + bbr_advance_latest_delivery_signals(sk, rs, &ctx); + bbr->prev_ca_state = inet_csk(sk)->icsk_ca_state; + bbr->loss_in_cycle |= rs->lost > 0; + bbr->ecn_in_cycle |= rs->delivered_ce > 0; +} + +__bpf_kfunc static void bbr_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->initialized = 1; + + bbr->init_cwnd = min(0x7FU, tcp_snd_cwnd(tp)); + bbr->prior_cwnd = tp->prior_cwnd; + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + bbr->next_rtt_delivered = tp->delivered; + bbr->prev_ca_state = TCP_CA_Open; + + bbr->probe_rtt_done_stamp = 0; + bbr->probe_rtt_round_done = 0; + bbr->probe_rtt_min_us = tcp_min_rtt(tp); + bbr->probe_rtt_min_stamp = tcp_jiffies32; + bbr->min_rtt_us = tcp_min_rtt(tp); + bbr->min_rtt_stamp = tcp_jiffies32; + + bbr->has_seen_rtt = 0; + bbr_init_pacing_rate_from_rtt(sk); + + bbr->round_start = 0; + bbr->idle_restart = 0; + bbr->full_bw_reached = 0; + bbr->full_bw = 0; bbr->full_bw_cnt = 0; - bbr_reset_lt_bw_sampling(sk); - return tcp_snd_cwnd(tcp_sk(sk)); + bbr->cycle_mstamp = 0; + bbr->cycle_idx = 0; + + bbr_reset_startup_mode(sk); + + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = 0; + bbr->extra_acked[0] = 0; + bbr->extra_acked[1] = 0; + + bbr->ce_state = 0; + bbr->prior_rcv_nxt = tp->rcv_nxt; + bbr->try_fast_path = 0; + + cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); + + /* Start sampling ECN mark rate after first full flight is ACKed: */ + bbr->loss_round_delivered = tp->delivered + 1; + bbr->loss_round_start = 0; + bbr->undo_bw_lo = 0; + bbr->undo_inflight_lo = 0; + bbr->undo_inflight_hi = 0; + bbr->loss_events_in_round = 0; + bbr->startup_ecn_rounds = 0; + bbr_reset_congestion_signals(sk); + bbr->bw_lo = ~0U; + bbr->bw_hi[0] = 0; + bbr->bw_hi[1] = 0; + bbr->inflight_lo = ~0U; + bbr->inflight_hi = ~0U; + bbr_reset_full_bw(sk); + bbr->bw_probe_up_cnt = ~0U; + bbr->bw_probe_up_acks = 0; + bbr->bw_probe_up_rounds = 0; + bbr->probe_wait_us = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_INIT; + bbr->rounds_since_probe = 0; + bbr->bw_probe_samples = 0; + bbr->prev_probe_too_high = 0; + bbr->ecn_eligible = 0; + bbr->ecn_alpha = bbr_param(sk, ecn_alpha_init); + bbr->alpha_last_delivered = 0; + bbr->alpha_last_delivered_ce = 0; + bbr->plb.pause_until = 0; + + tp->fast_ack_mode = bbr_fast_ack_mode ? 1 : 0; + + if (bbr_can_use_ecn(sk)) + tp->ecn_flags |= TCP_ECN_ECT_PERMANENT; +} + +/* BBR marks the current round trip as a loss round. */ +static void bbr_note_loss(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + /* Capture "current" data over the full round trip of loss, to + * have a better chance of observing the full capacity of the path. + */ + if (!bbr->loss_in_round) /* first loss in this round trip? */ + bbr->loss_round_delivered = tp->delivered; /* set round trip */ + bbr->loss_in_round = 1; + bbr->loss_in_cycle = 1; } -/* Entering loss recovery, so save cwnd for when we exit or undo recovery. */ +/* Core TCP stack informs us that the given skb was just marked lost. */ +__bpf_kfunc static void bbr_skb_marked_lost(struct sock *sk, + const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + struct rate_sample rs = {}; + + bbr_note_loss(sk); + + if (!bbr->bw_probe_samples) + return; /* not an skb sent while probing for bandwidth */ + if (unlikely(!scb->tx.delivered_mstamp)) + return; /* skb was SACKed, reneged, marked lost; ignore it */ + /* We are probing for bandwidth. Construct a rate sample that + * estimates what happened in the flight leading up to this lost skb, + * then see if the loss rate went too high, and if so at which packet. + */ + rs.tx_in_flight = scb->tx.in_flight; + rs.lost = tp->lost - scb->tx.lost; + rs.is_app_limited = scb->tx.is_app_limited; + if (bbr_is_inflight_too_high(sk, &rs)) { + rs.tx_in_flight = bbr_inflight_hi_from_lost_skb(sk, &rs, skb); + bbr_handle_inflight_too_high(sk, &rs); + } +} + +static void bbr_run_loss_probe_recovery(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct rate_sample rs = {0}; + + bbr_note_loss(sk); + + if (!bbr->bw_probe_samples) + return; /* not sent while probing for bandwidth */ + /* We are probing for bandwidth. Construct a rate sample that + * estimates what happened in the flight leading up to this + * loss, then see if the loss rate went too high. + */ + rs.lost = 1; /* TLP probe repaired loss of a single segment */ + rs.tx_in_flight = bbr->inflight_latest + rs.lost; + rs.is_app_limited = tp->tlp_orig_data_app_limited; + if (bbr_is_inflight_too_high(sk, &rs)) + bbr_handle_inflight_too_high(sk, &rs); +} + +/* Revert short-term model if current loss recovery event was spurious. */ +__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_full_bw(sk); /* spurious slow-down; reset full bw detector */ + bbr->loss_in_round = 0; + + /* Revert to cwnd and other state saved before loss episode. */ + bbr->bw_lo = max(bbr->bw_lo, bbr->undo_bw_lo); + bbr->inflight_lo = max(bbr->inflight_lo, bbr->undo_inflight_lo); + bbr->inflight_hi = max(bbr->inflight_hi, bbr->undo_inflight_hi); + bbr->try_fast_path = 0; /* take slow path to set proper cwnd, pacing */ + return bbr->prior_cwnd; +} + +/* Entering loss recovery, so save state for when we undo recovery. */ __bpf_kfunc static u32 bbr_ssthresh(struct sock *sk) { + struct bbr *bbr = inet_csk_ca(sk); + bbr_save_cwnd(sk); + /* For undo, save state that adapts based on loss signal. */ + bbr->undo_bw_lo = bbr->bw_lo; + bbr->undo_inflight_lo = bbr->inflight_lo; + bbr->undo_inflight_hi = bbr->inflight_hi; return tcp_sk(sk)->snd_ssthresh; } +static enum tcp_bbr_phase bbr_get_phase(struct bbr *bbr) +{ + switch (bbr->mode) { + case BBR_STARTUP: + return BBR_PHASE_STARTUP; + case BBR_DRAIN: + return BBR_PHASE_DRAIN; + case BBR_PROBE_BW: + break; + case BBR_PROBE_RTT: + return BBR_PHASE_PROBE_RTT; + default: + return BBR_PHASE_INVALID; + } + switch (bbr->cycle_idx) { + case BBR_BW_PROBE_UP: + return BBR_PHASE_PROBE_BW_UP; + case BBR_BW_PROBE_DOWN: + return BBR_PHASE_PROBE_BW_DOWN; + case BBR_BW_PROBE_CRUISE: + return BBR_PHASE_PROBE_BW_CRUISE; + case BBR_BW_PROBE_REFILL: + return BBR_PHASE_PROBE_BW_REFILL; + default: + return BBR_PHASE_INVALID; + } +} + static size_t bbr_get_info(struct sock *sk, u32 ext, int *attr, - union tcp_cc_info *info) + union tcp_cc_info *info) { if (ext & (1 << (INET_DIAG_BBRINFO - 1)) || ext & (1 << (INET_DIAG_VEGASINFO - 1))) { - struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u64 bw = bbr_bw(sk); - - bw = bw * tp->mss_cache * USEC_PER_SEC >> BW_SCALE; - memset(&info->bbr, 0, sizeof(info->bbr)); - info->bbr.bbr_bw_lo = (u32)bw; - info->bbr.bbr_bw_hi = (u32)(bw >> 32); - info->bbr.bbr_min_rtt = bbr->min_rtt_us; - info->bbr.bbr_pacing_gain = bbr->pacing_gain; - info->bbr.bbr_cwnd_gain = bbr->cwnd_gain; + u64 bw = bbr_bw_bytes_per_sec(sk, bbr_bw(sk)); + u64 bw_hi = bbr_bw_bytes_per_sec(sk, bbr_max_bw(sk)); + u64 bw_lo = bbr->bw_lo == ~0U ? + ~0ULL : bbr_bw_bytes_per_sec(sk, bbr->bw_lo); + struct tcp_bbr_info *bbr_info = &info->bbr; + + memset(bbr_info, 0, sizeof(*bbr_info)); + bbr_info->bbr_bw_lo = (u32)bw; + bbr_info->bbr_bw_hi = (u32)(bw >> 32); + bbr_info->bbr_min_rtt = bbr->min_rtt_us; + bbr_info->bbr_pacing_gain = bbr->pacing_gain; + bbr_info->bbr_cwnd_gain = bbr->cwnd_gain; + bbr_info->bbr_bw_hi_lsb = (u32)bw_hi; + bbr_info->bbr_bw_hi_msb = (u32)(bw_hi >> 32); + bbr_info->bbr_bw_lo_lsb = (u32)bw_lo; + bbr_info->bbr_bw_lo_msb = (u32)(bw_lo >> 32); + bbr_info->bbr_mode = bbr->mode; + bbr_info->bbr_phase = (__u8)bbr_get_phase(bbr); + bbr_info->bbr_version = (__u8)BBR_VERSION; + bbr_info->bbr_inflight_lo = bbr->inflight_lo; + bbr_info->bbr_inflight_hi = bbr->inflight_hi; + bbr_info->bbr_extra_acked = bbr_extra_acked(sk); *attr = INET_DIAG_BBRINFO; - return sizeof(info->bbr); + return sizeof(*bbr_info); } return 0; } __bpf_kfunc static void bbr_set_state(struct sock *sk, u8 new_state) { + struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); if (new_state == TCP_CA_Loss) { - struct rate_sample rs = { .losses = 1 }; bbr->prev_ca_state = TCP_CA_Loss; - bbr->full_bw = 0; - bbr->round_start = 1; /* treat RTO like end of a round */ - bbr_lt_bw_sampling(sk, &rs); + tcp_plb_update_state_upon_rto(sk, &bbr->plb); + /* The tcp_write_timeout() call to sk_rethink_txhash() likely + * repathed this flow, so re-learn the min network RTT on the + * new path: + */ + bbr_reset_full_bw(sk); + if (!bbr_is_probing_bandwidth(sk) && bbr->inflight_lo == ~0U) { + /* bbr_adapt_lower_bounds() needs cwnd before + * we suffered an RTO, to update inflight_lo: + */ + bbr->inflight_lo = + max(tcp_snd_cwnd(tp), bbr->prior_cwnd); + } + } else if (bbr->prev_ca_state == TCP_CA_Loss && + new_state != TCP_CA_Loss) { + bbr_exit_loss_recovery(sk); } } + static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { - .flags = TCP_CONG_NON_RESTRICTED, + .flags = TCP_CONG_NON_RESTRICTED | TCP_CONG_WANTS_CE_EVENTS, .name = "bbr", .owner = THIS_MODULE, .init = bbr_init, .cong_control = bbr_main, .sndbuf_expand = bbr_sndbuf_expand, + .skb_marked_lost = bbr_skb_marked_lost, .undo_cwnd = bbr_undo_cwnd, .cwnd_event = bbr_cwnd_event, .ssthresh = bbr_ssthresh, - .min_tso_segs = bbr_min_tso_segs, + .tso_segs = bbr_tso_segs, .get_info = bbr_get_info, .set_state = bbr_set_state, }; @@ -1159,10 +2360,11 @@ BTF_KFUNCS_START(tcp_bbr_check_kfunc_ids) BTF_ID_FLAGS(func, bbr_init) BTF_ID_FLAGS(func, bbr_main) BTF_ID_FLAGS(func, bbr_sndbuf_expand) +BTF_ID_FLAGS(func, bbr_skb_marked_lost) BTF_ID_FLAGS(func, bbr_undo_cwnd) BTF_ID_FLAGS(func, bbr_cwnd_event) BTF_ID_FLAGS(func, bbr_ssthresh) -BTF_ID_FLAGS(func, bbr_min_tso_segs) +BTF_ID_FLAGS(func, bbr_tso_segs) BTF_ID_FLAGS(func, bbr_set_state) BTF_KFUNCS_END(tcp_bbr_check_kfunc_ids) @@ -1195,5 +2397,12 @@ MODULE_AUTHOR("Van Jacobson "); MODULE_AUTHOR("Neal Cardwell "); MODULE_AUTHOR("Yuchung Cheng "); MODULE_AUTHOR("Soheil Hassas Yeganeh "); +MODULE_AUTHOR("Priyaranjan Jha "); +MODULE_AUTHOR("Yousuk Seung "); +MODULE_AUTHOR("Kevin Yang "); +MODULE_AUTHOR("Arjun Roy "); +MODULE_AUTHOR("David Morley "); + MODULE_LICENSE("Dual BSD/GPL"); MODULE_DESCRIPTION("TCP BBR (Bottleneck Bandwidth and RTT)"); +MODULE_VERSION(__stringify(BBR_VERSION)); diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c index df758adbb445..e98e5dbc050e 100644 --- a/net/ipv4/tcp_cong.c +++ b/net/ipv4/tcp_cong.c @@ -237,6 +237,7 @@ void tcp_init_congestion_control(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); tcp_sk(sk)->prior_ssthresh = 0; + tcp_sk(sk)->fast_ack_mode = 0; if (icsk->icsk_ca_ops->init) icsk->icsk_ca_ops->init(sk); if (tcp_ca_needs_ecn(sk)) diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 0cbf81bf3d45..7e8324f54563 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -376,7 +376,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tcp_enter_quickack_mode(sk, 2); break; case INET_ECN_CE: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { @@ -387,7 +387,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tp->ecn_flags |= TCP_ECN_SEEN; break; default: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); tp->ecn_flags |= TCP_ECN_SEEN; break; @@ -1126,7 +1126,12 @@ static void tcp_verify_retransmit_hint(struct tcp_sock *tp, struct sk_buff *skb) */ static void tcp_notify_skb_loss_event(struct tcp_sock *tp, const struct sk_buff *skb) { + struct sock *sk = (struct sock *)tp; + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + tp->lost += tcp_skb_pcount(skb); + if (ca_ops->skb_marked_lost) + ca_ops->skb_marked_lost(sk, skb); } void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) @@ -1507,6 +1512,17 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount); tcp_skb_pcount_add(skb, -pcount); + /* Adjust tx.in_flight as pcount is shifted from skb to prev. */ + if (WARN_ONCE(TCP_SKB_CB(skb)->tx.in_flight < pcount, + "prev in_flight: %u skb in_flight: %u pcount: %u", + TCP_SKB_CB(prev)->tx.in_flight, + TCP_SKB_CB(skb)->tx.in_flight, + pcount)) + TCP_SKB_CB(skb)->tx.in_flight = 0; + else + TCP_SKB_CB(skb)->tx.in_flight -= pcount; + TCP_SKB_CB(prev)->tx.in_flight += pcount; + /* When we're adding to gso_segs == 1, gso_size will be zero, * in theory this shouldn't be necessary but as long as DSACK * code can come after this skb later on it's better to keep @@ -3832,7 +3848,8 @@ static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) /* This routine deals with acks during a TLP episode and ends an episode by * resetting tlp_high_seq. Ref: TLP algorithm in draft-ietf-tcpm-rack */ -static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) +static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag, + struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); @@ -3849,6 +3866,7 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) /* ACK advances: there was a loss, so reduce cwnd. Reset * tlp_high_seq in tcp_init_cwnd_reduction() */ + tcp_ca_event(sk, CA_EVENT_TLP_RECOVERY); tcp_init_cwnd_reduction(sk); tcp_set_ca_state(sk, TCP_CA_CWR); tcp_end_cwnd_reduction(sk); @@ -3859,6 +3877,11 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) FLAG_NOT_DUP | FLAG_DATA_SACKED))) { /* Pure dupack: original and TLP probe arrived; no loss */ tp->tlp_high_seq = 0; + } else { + /* This ACK matches a TLP retransmit. We cannot yet tell if + * this ACK is for the original or the TLP retransmit. + */ + rs->is_acking_tlp_retrans_seq = 1; } } @@ -3967,6 +3990,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) prior_fack = tcp_is_sack(tp) ? tcp_highest_sack_seq(tp) : tp->snd_una; rs.prior_in_flight = tcp_packets_in_flight(tp); + tcp_rate_check_app_limited(sk); /* ts_recent update must be made after we are sure that the packet * is in window. @@ -4041,7 +4065,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_rack_update_reo_wnd(sk, &rs); if (tp->tlp_high_seq) - tcp_process_tlp_ack(sk, ack, flag); + tcp_process_tlp_ack(sk, ack, flag, &rs); if (tcp_ack_is_dubious(sk, flag)) { if (!(flag & (FLAG_SND_UNA_ADVANCED | @@ -4065,6 +4089,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) delivered = tcp_newly_delivered(sk, delivered, flag); lost = tp->lost - lost; /* freshly marked lost */ rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); + rs.is_ece = !!(flag & FLAG_ECE); tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); tcp_cong_control(sk, ack, delivered, flag, sack_state.rate); tcp_xmit_recovery(sk, rexmit); @@ -4084,7 +4109,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_ack_probe(sk); if (tp->tlp_high_seq) - tcp_process_tlp_ack(sk, ack, flag); + tcp_process_tlp_ack(sk, ack, flag, &rs); return 1; old_ack: @@ -5764,13 +5789,14 @@ static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) /* More than one full frame received... */ if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && + (tp->fast_ack_mode == 1 || /* ... and right edge of window advances far enough. * (tcp_recvmsg() will send ACK otherwise). * If application uses SO_RCVLOWAT, we want send ack now if * we have not received enough bytes to satisfy the condition. */ - (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || - __tcp_select_window(sk) >= tp->rcv_wnd)) || + (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || + __tcp_select_window(sk) >= tp->rcv_wnd))) || /* We ACK each frame or... */ tcp_in_quickack_mode(sk) || /* Protocol state mandates a one-time immediate ACK */ diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index dfdb7a4608a8..874e99902bba 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -471,6 +471,8 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) u32 ca_key = dst_metric(dst, RTAX_CC_ALGO); bool ca_got_dst = false; + tcp_set_ecn_low_from_dst(sk, dst); + if (ca_key != TCP_CA_UNSPEC) { const struct tcp_congestion_ops *ca; diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index bc95d2a5924f..d4c45ca6fe06 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -339,10 +339,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk); bool use_ecn = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn) == 1 || tcp_ca_needs_ecn(sk) || bpf_needs_ecn; + const struct dst_entry *dst = __sk_dst_get(sk); if (!use_ecn) { - const struct dst_entry *dst = __sk_dst_get(sk); - if (dst && dst_feature(dst, RTAX_FEATURE_ECN)) use_ecn = true; } @@ -354,6 +353,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) tp->ecn_flags = TCP_ECN_OK; if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn) INET_ECN_xmit(sk); + + if (dst) + tcp_set_ecn_low_from_dst(sk, dst); } } @@ -391,7 +393,8 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, th->cwr = 1; skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; } - } else if (!tcp_ca_needs_ecn(sk)) { + } else if (!(tp->ecn_flags & TCP_ECN_ECT_PERMANENT) && + !tcp_ca_needs_ecn(sk)) { /* ACK or retransmitted segment: clear ECT|CE */ INET_ECN_dontxmit(sk); } @@ -1606,7 +1609,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, { struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *buff; - int old_factor; + int old_factor, inflight_prev; long limit; int nlen; u8 flags; @@ -1681,6 +1684,30 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, if (diff) tcp_adjust_pcount(sk, skb, diff); + + inflight_prev = TCP_SKB_CB(skb)->tx.in_flight - old_factor; + if (inflight_prev < 0) { + WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( + old_factor, + TCP_SKB_CB(skb)->sacked, + TCP_SKB_CB(skb)->tx.in_flight), + "inconsistent: tx.in_flight: %u " + "old_factor: %d mss: %u sacked: %u " + "1st pcount: %d 2nd pcount: %d " + "1st len: %u 2nd len: %u ", + TCP_SKB_CB(skb)->tx.in_flight, old_factor, + mss_now, TCP_SKB_CB(skb)->sacked, + tcp_skb_pcount(skb), tcp_skb_pcount(buff), + skb->len, buff->len); + inflight_prev = 0; + } + /* Set 1st tx.in_flight as if 1st were sent by itself: */ + TCP_SKB_CB(skb)->tx.in_flight = inflight_prev + + tcp_skb_pcount(skb); + /* Set 2nd tx.in_flight with new 1st and 2nd pcounts: */ + TCP_SKB_CB(buff)->tx.in_flight = inflight_prev + + tcp_skb_pcount(skb) + + tcp_skb_pcount(buff); } /* Link BUFF into the send queue. */ @@ -2038,13 +2065,12 @@ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) { const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; - u32 min_tso, tso_segs; - - min_tso = ca_ops->min_tso_segs ? - ca_ops->min_tso_segs(sk) : - READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + u32 tso_segs; - tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); + tso_segs = ca_ops->tso_segs ? + ca_ops->tso_segs(sk, mss_now) : + tcp_tso_autosize(sk, mss_now, + sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); return min_t(u32, tso_segs, sk->sk_gso_max_segs); } @@ -2770,6 +2796,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, skb_set_delivery_time(skb, tp->tcp_wstamp_ns, SKB_CLOCK_MONOTONIC); list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); tcp_init_tso_segs(skb, mss_now); + tcp_set_tx_in_flight(sk, skb); goto repair; /* Skip network transmission */ } @@ -2982,6 +3009,7 @@ void tcp_send_loss_probe(struct sock *sk) if (WARN_ON(!skb || !tcp_skb_pcount(skb))) goto rearm_timer; + tp->tlp_orig_data_app_limited = TCP_SKB_CB(skb)->tx.is_app_limited; if (__tcp_retransmit_skb(sk, skb, 1)) goto rearm_timer; diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c index a8f6d9d06f2e..8737f2134648 100644 --- a/net/ipv4/tcp_rate.c +++ b/net/ipv4/tcp_rate.c @@ -34,6 +34,24 @@ * ready to send in the write queue. */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 in_flight; + + /* Check, sanitize, and record packets in flight after skb was sent. */ + in_flight = tcp_packets_in_flight(tp) + tcp_skb_pcount(skb); + if (WARN_ONCE(in_flight > TCPCB_IN_FLIGHT_MAX, + "insane in_flight %u cc %s mss %u " + "cwnd %u pif %u %u %u %u\n", + in_flight, inet_csk(sk)->icsk_ca_ops->name, + tp->mss_cache, tp->snd_cwnd, + tp->packets_out, tp->retrans_out, + tp->sacked_out, tp->lost_out)) + in_flight = TCPCB_IN_FLIGHT_MAX; + TCP_SKB_CB(skb)->tx.in_flight = in_flight; +} + /* Snapshot the current delivery information in the skb, to generate * a rate sample later when the skb is (s)acked in tcp_rate_skb_delivered(). */ @@ -66,7 +84,9 @@ void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->tx.delivered_mstamp = tp->delivered_mstamp; TCP_SKB_CB(skb)->tx.delivered = tp->delivered; TCP_SKB_CB(skb)->tx.delivered_ce = tp->delivered_ce; + TCP_SKB_CB(skb)->tx.lost = tp->lost; TCP_SKB_CB(skb)->tx.is_app_limited = tp->app_limited ? 1 : 0; + tcp_set_tx_in_flight(sk, skb); } /* When an skb is sacked or acked, we fill in the rate sample with the (prior) @@ -91,18 +111,21 @@ void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, if (!rs->prior_delivered || tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, scb->end_seq, rs->last_end_seq)) { + rs->prior_lost = scb->tx.lost; rs->prior_delivered_ce = scb->tx.delivered_ce; rs->prior_delivered = scb->tx.delivered; rs->prior_mstamp = scb->tx.delivered_mstamp; rs->is_app_limited = scb->tx.is_app_limited; rs->is_retrans = scb->sacked & TCPCB_RETRANS; + rs->tx_in_flight = scb->tx.in_flight; rs->last_end_seq = scb->end_seq; /* Record send time of most recently ACKed packet: */ tp->first_tx_mstamp = tx_tstamp; /* Find the duration of the "send phase" of this window: */ - rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, - scb->tx.first_tx_mstamp); + rs->interval_us = tcp_stamp32_us_delta( + tp->first_tx_mstamp, + scb->tx.first_tx_mstamp); } /* Mark off the skb delivered once it's sacked to avoid being @@ -144,6 +167,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, return; } rs->delivered = tp->delivered - rs->prior_delivered; + rs->lost = tp->lost - rs->prior_lost; rs->delivered_ce = tp->delivered_ce - rs->prior_delivered_ce; /* delivered_ce occupies less than 32 bits in the skb control block */ @@ -155,7 +179,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, * longer phase. */ snd_us = rs->interval_us; /* send phase */ - ack_us = tcp_stamp_us_delta(tp->tcp_mstamp, + ack_us = tcp_stamp32_us_delta(tp->tcp_mstamp, rs->prior_mstamp); /* ack phase */ rs->interval_us = max(snd_us, ack_us); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index b412ed88ccd9..d70f8b742b21 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -699,6 +699,7 @@ void tcp_write_timer_handler(struct sock *sk) return; } + tcp_rate_check_app_limited(sk); tcp_mstamp_refresh(tcp_sk(sk)); event = icsk->icsk_pending; -- 2.49.0.391.g4bbb303af6 From 9a73ad96afdfd2a8095026eff238741ec57574f0 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:32:17 +0200 Subject: [PATCH 5/9] cachy Signed-off-by: Peter Jung --- .../admin-guide/kernel-parameters.txt | 12 + Makefile | 8 + arch/x86/Kconfig.cpu | 367 +- arch/x86/Makefile | 89 +- arch/x86/include/asm/pci.h | 6 + arch/x86/include/asm/vermagic.h | 72 + arch/x86/pci/common.c | 7 +- block/Kconfig.iosched | 9 + block/Makefile | 8 + block/adios.c | 1342 +++++++ block/elevator.c | 8 + drivers/Makefile | 13 +- drivers/ata/ahci.c | 23 +- drivers/cpufreq/Kconfig.x86 | 2 - drivers/cpufreq/intel_pstate.c | 2 + drivers/gpu/drm/amd/amdgpu/amdgpu.h | 1 + drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.c | 44 +- drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.h | 1 + drivers/gpu/drm/amd/amdgpu/amdgpu_device.c | 6 +- drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 10 + drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.c | 19 + drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.h | 1 + drivers/gpu/drm/amd/amdgpu/amdgpu_mode.h | 1 + drivers/gpu/drm/amd/display/Kconfig | 6 + .../gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c | 69 +- .../gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.h | 7 + .../amd/display/amdgpu_dm/amdgpu_dm_color.c | 2 +- .../amd/display/amdgpu_dm/amdgpu_dm_crtc.c | 6 +- .../amd/display/amdgpu_dm/amdgpu_dm_plane.c | 6 +- .../drm/amd/display/dc/bios/bios_parser2.c | 13 +- .../drm/amd/display/dc/core/dc_link_exports.c | 6 + drivers/gpu/drm/amd/display/dc/dc.h | 3 + .../dc/resource/dce120/dce120_resource.c | 17 + drivers/gpu/drm/amd/pm/amdgpu_pm.c | 3 + drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c | 14 +- drivers/input/evdev.c | 19 +- drivers/md/dm-crypt.c | 5 + drivers/media/v4l2-core/Kconfig | 5 + drivers/media/v4l2-core/Makefile | 2 + drivers/media/v4l2-core/v4l2loopback.c | 3292 +++++++++++++++++ drivers/media/v4l2-core/v4l2loopback.h | 98 + .../media/v4l2-core/v4l2loopback_formats.h | 445 +++ drivers/pci/controller/Makefile | 6 + drivers/pci/controller/intel-nvme-remap.c | 462 +++ drivers/pci/quirks.c | 101 + drivers/scsi/Kconfig | 2 + drivers/scsi/Makefile | 1 + drivers/scsi/vhba/Kconfig | 9 + drivers/scsi/vhba/Makefile | 4 + drivers/scsi/vhba/vhba.c | 1132 ++++++ include/linux/pagemap.h | 2 +- include/linux/user_namespace.h | 4 + include/linux/wait.h | 2 + init/Kconfig | 26 + kernel/Kconfig.hz | 24 + kernel/Kconfig.preempt | 2 +- kernel/fork.c | 14 + kernel/locking/rwsem.c | 4 +- kernel/sched/fair.c | 24 +- kernel/sched/sched.h | 2 +- kernel/sched/wait.c | 24 + kernel/sysctl.c | 12 + kernel/user_namespace.c | 7 + mm/Kconfig | 2 +- mm/compaction.c | 4 + mm/huge_memory.c | 4 + mm/page-writeback.c | 8 + mm/page_alloc.c | 4 + mm/swap.c | 5 + mm/vmpressure.c | 4 + mm/vmscan.c | 4 + net/ipv4/inet_connection_sock.c | 2 +- 72 files changed, 7871 insertions(+), 99 deletions(-) create mode 100644 block/adios.c create mode 100644 drivers/media/v4l2-core/v4l2loopback.c create mode 100644 drivers/media/v4l2-core/v4l2loopback.h create mode 100644 drivers/media/v4l2-core/v4l2loopback_formats.h create mode 100644 drivers/pci/controller/intel-nvme-remap.c create mode 100644 drivers/scsi/vhba/Kconfig create mode 100644 drivers/scsi/vhba/Makefile create mode 100644 drivers/scsi/vhba/vhba.c diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt index aa7447f8837c..00e111ab9a0b 100644 --- a/Documentation/admin-guide/kernel-parameters.txt +++ b/Documentation/admin-guide/kernel-parameters.txt @@ -2277,6 +2277,9 @@ disable Do not enable intel_pstate as the default scaling driver for the supported processors + enable + Enable intel_pstate in-case "disable" was passed + previously in the kernel boot parameters active Use intel_pstate driver to bypass the scaling governors layer of cpufreq and provides it own @@ -4644,6 +4647,15 @@ nomsi [MSI] If the PCI_MSI kernel config parameter is enabled, this kernel boot option can be used to disable the use of MSI interrupts system-wide. + pcie_acs_override = + [PCIE] Override missing PCIe ACS support for: + downstream + All downstream ports - full ACS capabilities + multfunction + All multifunction devices - multifunction ACS subset + id:nnnn:nnnn + Specfic device - full ACS capabilities + Specified as vid:did (vendor/device ID) in hex noioapicquirk [APIC] Disable all boot interrupt quirks. Safety option to keep boot IRQs enabled. This should never be necessary. diff --git a/Makefile b/Makefile index 93870f58505f..9265996043dd 100644 --- a/Makefile +++ b/Makefile @@ -861,11 +861,19 @@ KBUILD_CFLAGS += -fno-delete-null-pointer-checks ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE KBUILD_CFLAGS += -O2 KBUILD_RUSTFLAGS += -Copt-level=2 +else ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3 +KBUILD_CFLAGS += -O3 +KBUILD_RUSTFLAGS += -Copt-level=3 else ifdef CONFIG_CC_OPTIMIZE_FOR_SIZE KBUILD_CFLAGS += -Os KBUILD_RUSTFLAGS += -Copt-level=s endif +# Perform swing modulo scheduling immediately before the first scheduling pass. +# This pass looks at innermost loops and reorders their instructions by +# overlapping different iterations. +KBUILD_CFLAGS += $(call cc-option,-fmodulo-sched -fmodulo-sched-allow-regmoves -fivopts -fmodulo-sched) + # Always set `debug-assertions` and `overflow-checks` because their default # depends on `opt-level` and `debug-assertions`, respectively. KBUILD_RUSTFLAGS += -Cdebug-assertions=$(if $(CONFIG_RUST_DEBUG_ASSERTIONS),y,n) diff --git a/arch/x86/Kconfig.cpu b/arch/x86/Kconfig.cpu index 9bade26f6c67..00ae7c345e5c 100644 --- a/arch/x86/Kconfig.cpu +++ b/arch/x86/Kconfig.cpu @@ -155,9 +155,8 @@ config MPENTIUM4 -Paxville -Dempsey - config MK6 - bool "K6/K6-II/K6-III" + bool "AMD K6/K6-II/K6-III" depends on X86_32 help Select this for an AMD K6-family processor. Enables use of @@ -165,7 +164,7 @@ config MK6 flags to GCC. config MK7 - bool "Athlon/Duron/K7" + bool "AMD Athlon/Duron/K7" depends on X86_32 help Select this for an AMD Athlon K7-family processor. Enables use of @@ -173,12 +172,114 @@ config MK7 flags to GCC. config MK8 - bool "Opteron/Athlon64/Hammer/K8" + bool "AMD Opteron/Athlon64/Hammer/K8" help Select this for an AMD Opteron or Athlon64 Hammer-family processor. Enables use of some extended instructions, and passes appropriate optimization flags to GCC. +config MK8SSE3 + bool "AMD Opteron/Athlon64/Hammer/K8 with SSE3" + help + Select this for improved AMD Opteron or Athlon64 Hammer-family processors. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MK10 + bool "AMD 61xx/7x50/PhenomX3/X4/II/K10" + help + Select this for an AMD 61xx Eight-Core Magny-Cours, Athlon X2 7x50, + Phenom X3/X4/II, Athlon II X2/X3/X4, or Turion II-family processor. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MBARCELONA + bool "AMD Barcelona" + help + Select this for AMD Family 10h Barcelona processors. + + Enables -march=barcelona + +config MBOBCAT + bool "AMD Bobcat" + help + Select this for AMD Family 14h Bobcat processors. + + Enables -march=btver1 + +config MJAGUAR + bool "AMD Jaguar" + help + Select this for AMD Family 16h Jaguar processors. + + Enables -march=btver2 + +config MBULLDOZER + bool "AMD Bulldozer" + help + Select this for AMD Family 15h Bulldozer processors. + + Enables -march=bdver1 + +config MPILEDRIVER + bool "AMD Piledriver" + help + Select this for AMD Family 15h Piledriver processors. + + Enables -march=bdver2 + +config MSTEAMROLLER + bool "AMD Steamroller" + help + Select this for AMD Family 15h Steamroller processors. + + Enables -march=bdver3 + +config MEXCAVATOR + bool "AMD Excavator" + help + Select this for AMD Family 15h Excavator processors. + + Enables -march=bdver4 + +config MZEN + bool "AMD Zen" + help + Select this for AMD Family 17h Zen processors. + + Enables -march=znver1 + +config MZEN2 + bool "AMD Zen 2" + help + Select this for AMD Family 17h Zen 2 processors. + + Enables -march=znver2 + +config MZEN3 + bool "AMD Zen 3" + depends on (CC_IS_GCC && GCC_VERSION >= 100300) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + Select this for AMD Family 19h Zen 3 processors. + + Enables -march=znver3 + +config MZEN4 + bool "AMD Zen 4" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 160000) + help + Select this for AMD Family 19h Zen 4 processors. + + Enables -march=znver4 + +config MZEN5 + bool "AMD Zen 5" + depends on (CC_IS_GCC && GCC_VERSION > 140000) || (CC_IS_CLANG && CLANG_VERSION >= 190100) + help + Select this for AMD Family 19h Zen 5 processors. + + Enables -march=znver5 + config MCRUSOE bool "Crusoe" depends on X86_32 @@ -269,8 +370,17 @@ config MPSC using the cpu family field in /proc/cpuinfo. Family 15 is an older Xeon, Family 6 a newer one. +config MATOM + bool "Intel Atom" + help + + Select this for the Intel Atom platform. Intel Atom CPUs have an + in-order pipelining architecture and thus can benefit from + accordingly optimized code. Use a recent GCC with specific Atom + support in order to fully benefit from selecting this option. + config MCORE2 - bool "Core 2/newer Xeon" + bool "Intel Core 2" help Select this for Intel Core 2 and newer Core 2 Xeons (Xeon 51xx and @@ -278,14 +388,199 @@ config MCORE2 family in /proc/cpuinfo. Newer ones have 6 and older ones 15 (not a typo) -config MATOM - bool "Intel Atom" + Enables -march=core2 + +config MNEHALEM + bool "Intel Nehalem" help - Select this for the Intel Atom platform. Intel Atom CPUs have an - in-order pipelining architecture and thus can benefit from - accordingly optimized code. Use a recent GCC with specific Atom - support in order to fully benefit from selecting this option. + Select this for 1st Gen Core processors in the Nehalem family. + + Enables -march=nehalem + +config MWESTMERE + bool "Intel Westmere" + help + + Select this for the Intel Westmere formerly Nehalem-C family. + + Enables -march=westmere + +config MSILVERMONT + bool "Intel Silvermont" + help + + Select this for the Intel Silvermont platform. + + Enables -march=silvermont + +config MGOLDMONT + bool "Intel Goldmont" + help + + Select this for the Intel Goldmont platform including Apollo Lake and Denverton. + + Enables -march=goldmont + +config MGOLDMONTPLUS + bool "Intel Goldmont Plus" + help + + Select this for the Intel Goldmont Plus platform including Gemini Lake. + + Enables -march=goldmont-plus + +config MSANDYBRIDGE + bool "Intel Sandy Bridge" + help + + Select this for 2nd Gen Core processors in the Sandy Bridge family. + + Enables -march=sandybridge + +config MIVYBRIDGE + bool "Intel Ivy Bridge" + help + + Select this for 3rd Gen Core processors in the Ivy Bridge family. + + Enables -march=ivybridge + +config MHASWELL + bool "Intel Haswell" + help + + Select this for 4th Gen Core processors in the Haswell family. + + Enables -march=haswell + +config MBROADWELL + bool "Intel Broadwell" + help + + Select this for 5th Gen Core processors in the Broadwell family. + + Enables -march=broadwell + +config MSKYLAKE + bool "Intel Skylake" + help + + Select this for 6th Gen Core processors in the Skylake family. + + Enables -march=skylake + +config MSKYLAKEX + bool "Intel Skylake X" + help + + Select this for 6th Gen Core processors in the Skylake X family. + + Enables -march=skylake-avx512 + +config MCANNONLAKE + bool "Intel Cannon Lake" + help + + Select this for 8th Gen Core processors + + Enables -march=cannonlake + +config MICELAKE_CLIENT + bool "Intel Ice Lake (Client)" + help + + Select this for 10th Gen Core processors in the Ice Lake family. + + Enables -march=icelake-client + +config MICELAKE_SERVER + bool "Intel Ice lake (Server)" + help + + Select this for the 3rd Gen Xeon processors in the Ice lake family. + + Enables -march=icelake-server + +config MCASCADELAKE + bool "Intel Cascade Lake" + help + + Select this for Xeon processors in the Cascade Lake family. + + Enables -march=cascadelake + +config MCOOPERLAKE + bool "Intel Cooper Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + help + + Select this for Xeon processors in the Cooper Lake family. + + Enables -march=cooperlake + +config MTIGERLAKE + bool "Intel Tiger Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + help + + Select this for third-generation 10 nm process processors in the Tiger Lake family. + + Enables -march=tigerlake + +config MSAPPHIRERAPIDS + bool "Intel Sapphire Rapids" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for fourth-generation 10 nm process processors in the Sapphire Rapids family. + + Enables -march=sapphirerapids + +config MROCKETLAKE + bool "Intel Rocket Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for eleventh-generation processors in the Rocket Lake family. + + Enables -march=rocketlake + +config MALDERLAKE + bool "Intel Alder Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for twelfth-generation processors in the Alder Lake family. + + Enables -march=alderlake + +config MRAPTORLAKE + bool "Intel Raptor Lake" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for thirteenth-generation processors in the Raptor Lake family. + + Enables -march=raptorlake + +config MMETEORLAKE + bool "Intel Meteor Lake" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for fourteenth-generation processors in the Meteor Lake family. + + Enables -march=meteorlake + +config MEMERALDRAPIDS + bool "Intel Emerald Rapids" + depends on (CC_IS_GCC && GCC_VERSION > 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for fifth-generation 10 nm process processors in the Emerald Rapids family. + + Enables -march=emeraldrapids config GENERIC_CPU bool "Generic-x86-64" @@ -294,6 +589,26 @@ config GENERIC_CPU Generic x86-64 CPU. Run equally well on all x86-64 CPUs. +config MNATIVE_INTEL + bool "Intel-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for AMD CPUs. Intel Only! + + Enables -march=native + +config MNATIVE_AMD + bool "AMD-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for Intel CPUs. AMD Only! + + Enables -march=native + endchoice config X86_GENERIC @@ -308,6 +623,30 @@ config X86_GENERIC This is really intended for distributors who need more generic optimizations. +config X86_64_VERSION + int "x86-64 compiler ISA level" + range 1 4 + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + depends on X86_64 && GENERIC_CPU + help + Specify a specific x86-64 compiler ISA level. + + There are three x86-64 ISA levels that work on top of + the x86-64 baseline, namely: x86-64-v2, x86-64-v3, and x86-64-v4. + + x86-64-v2 brings support for vector instructions up to Streaming SIMD + Extensions 4.2 (SSE4.2) and Supplemental Streaming SIMD Extensions 3 + (SSSE3), the POPCNT instruction, and CMPXCHG16B. + + x86-64-v3 adds vector instructions up to AVX2, MOVBE, and additional + bit-manipulation instructions. + + x86-64-v4 is not included since the kernel does not use AVX512 instructions + + You can find the best version for your CPU by running one of the following: + /lib/ld-linux-x86-64.so.2 --help | grep supported + /lib64/ld-linux-x86-64.so.2 --help | grep supported + # # Define implied options from the CPU selection here config X86_INTERNODE_CACHE_SHIFT @@ -318,7 +657,7 @@ config X86_INTERNODE_CACHE_SHIFT config X86_L1_CACHE_SHIFT int default "7" if MPENTIUM4 || MPSC - default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU + default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE_CLIENT || MICELAKE_SERVER || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD default "4" if MELAN || M486SX || M486 || MGEODEGX1 default "5" if MWINCHIP3D || MWINCHIPC6 || MCRUSOE || MEFFICEON || MCYRIXIII || MK6 || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || M586 || MVIAC3_2 || MGEODE_LX @@ -336,11 +675,11 @@ config X86_ALIGNMENT_16 config X86_INTEL_USERCOPY def_bool y - depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 + depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE_CLIENT || MICELAKE_SERVER || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL config X86_USE_PPRO_CHECKSUM def_bool y - depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM + depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE_CLIENT || MICELAKE_SERVER || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD # # P6_NOPs are a relatively minor optimization that require a family >= diff --git a/arch/x86/Makefile b/arch/x86/Makefile index 5b773b34768d..5f57fd8750c6 100644 --- a/arch/x86/Makefile +++ b/arch/x86/Makefile @@ -182,15 +182,98 @@ else cflags-$(CONFIG_MK8) += -march=k8 cflags-$(CONFIG_MPSC) += -march=nocona cflags-$(CONFIG_MCORE2) += -march=core2 - cflags-$(CONFIG_MATOM) += -march=atom - cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic + cflags-$(CONFIG_MATOM) += -march=bonnell + ifeq ($(CONFIG_X86_64_VERSION),1) + cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic + rustflags-$(CONFIG_GENERIC_CPU) += -Ztune-cpu=generic + else + cflags-$(CONFIG_GENERIC_CPU) += -march=x86-64-v$(CONFIG_X86_64_VERSION) + rustflags-$(CONFIG_GENERIC_CPU) += -Ctarget-cpu=x86-64-v$(CONFIG_X86_64_VERSION) + endif + cflags-$(CONFIG_MK8SSE3) += -march=k8-sse3 + cflags-$(CONFIG_MK10) += -march=amdfam10 + cflags-$(CONFIG_MBARCELONA) += -march=barcelona + cflags-$(CONFIG_MBOBCAT) += -march=btver1 + cflags-$(CONFIG_MJAGUAR) += -march=btver2 + cflags-$(CONFIG_MBULLDOZER) += -march=bdver1 + cflags-$(CONFIG_MPILEDRIVER) += -march=bdver2 -mno-tbm + cflags-$(CONFIG_MSTEAMROLLER) += -march=bdver3 -mno-tbm + cflags-$(CONFIG_MEXCAVATOR) += -march=bdver4 -mno-tbm + cflags-$(CONFIG_MZEN) += -march=znver1 + cflags-$(CONFIG_MZEN2) += -march=znver2 + cflags-$(CONFIG_MZEN3) += -march=znver3 + cflags-$(CONFIG_MZEN4) += -march=znver4 + cflags-$(CONFIG_MZEN5) += -march=znver5 + cflags-$(CONFIG_MNATIVE_INTEL) += -march=native + cflags-$(CONFIG_MNATIVE_AMD) += -march=native -mno-tbm + cflags-$(CONFIG_MNEHALEM) += -march=nehalem + cflags-$(CONFIG_MWESTMERE) += -march=westmere + cflags-$(CONFIG_MSILVERMONT) += -march=silvermont + cflags-$(CONFIG_MGOLDMONT) += -march=goldmont + cflags-$(CONFIG_MGOLDMONTPLUS) += -march=goldmont-plus + cflags-$(CONFIG_MSANDYBRIDGE) += -march=sandybridge + cflags-$(CONFIG_MIVYBRIDGE) += -march=ivybridge + cflags-$(CONFIG_MHASWELL) += -march=haswell + cflags-$(CONFIG_MBROADWELL) += -march=broadwell + cflags-$(CONFIG_MSKYLAKE) += -march=skylake + cflags-$(CONFIG_MSKYLAKEX) += -march=skylake-avx512 + cflags-$(CONFIG_MCANNONLAKE) += -march=cannonlake + cflags-$(CONFIG_MICELAKE_CLIENT) += -march=icelake-client + cflags-$(CONFIG_MICELAKE_SERVER) += -march=icelake-server + cflags-$(CONFIG_MCASCADELAKE) += -march=cascadelake + cflags-$(CONFIG_MCOOPERLAKE) += -march=cooperlake + cflags-$(CONFIG_MTIGERLAKE) += -march=tigerlake + cflags-$(CONFIG_MSAPPHIRERAPIDS) += -march=sapphirerapids + cflags-$(CONFIG_MROCKETLAKE) += -march=rocketlake + cflags-$(CONFIG_MALDERLAKE) += -march=alderlake + cflags-$(CONFIG_MRAPTORLAKE) += -march=raptorlake + cflags-$(CONFIG_MMETEORLAKE) += -march=meteorlake + cflags-$(CONFIG_MEMERALDRAPIDS) += -march=emeraldrapids KBUILD_CFLAGS += $(cflags-y) rustflags-$(CONFIG_MK8) += -Ctarget-cpu=k8 rustflags-$(CONFIG_MPSC) += -Ctarget-cpu=nocona rustflags-$(CONFIG_MCORE2) += -Ctarget-cpu=core2 rustflags-$(CONFIG_MATOM) += -Ctarget-cpu=atom - rustflags-$(CONFIG_GENERIC_CPU) += -Ztune-cpu=generic + rustflags-$(CONFIG_MK8SSE3) += -Ctarget-cpu=k8-sse3 + rustflags-$(CONFIG_MK10) += -Ctarget-cpu=amdfam10 + rustflags-$(CONFIG_MBARCELONA) += -Ctarget-cpu=barcelona + rustflags-$(CONFIG_MBOBCAT) += -Ctarget-cpu=btver1 + rustflags-$(CONFIG_MJAGUAR) += -Ctarget-cpu=btver2 + rustflags-$(CONFIG_MBULLDOZER) += -Ctarget-cpu=bdver1 + rustflags-$(CONFIG_MPILEDRIVER) += -Ctarget-cpu=bdver2 + rustflags-$(CONFIG_MSTEAMROLLER) += -Ctarget-cpu=bdver3 + rustflags-$(CONFIG_MEXCAVATOR) += -Ctarget-cpu=bdver4 + rustflags-$(CONFIG_MZEN) += -Ctarget-cpu=znver1 + rustflags-$(CONFIG_MZEN2) += -Ctarget-cpu=znver2 + rustflags-$(CONFIG_MZEN3) += -Ctarget-cpu=znver3 + rustflags-$(CONFIG_MZEN4) += -Ctarget-cpu=znver4 + rustflags-$(CONFIG_MZEN5) += -Ctarget-cpu=znver5 + rustflags-$(CONFIG_MNATIVE_INTEL) += -Ctarget-cpu=native + rustflags-$(CONFIG_MNATIVE_AMD) += -Ctarget-cpu=native + rustflags-$(CONFIG_MNEHALEM) += -Ctarget-cpu=nehalem + rustflags-$(CONFIG_MWESTMERE) += -Ctarget-cpu=westmere + rustflags-$(CONFIG_MSILVERMONT) += -Ctarget-cpu=silvermont + rustflags-$(CONFIG_MGOLDMONT) += -Ctarget-cpu=goldmont + rustflags-$(CONFIG_MGOLDMONTPLUS) += -Ctarget-cpu=goldmont-plus + rustflags-$(CONFIG_MSANDYBRIDGE) += -Ctarget-cpu=sandybridge + rustflags-$(CONFIG_MIVYBRIDGE) += -Ctarget-cpu=ivybridge + rustflags-$(CONFIG_MHASWELL) += -Ctarget-cpu=haswell + rustflags-$(CONFIG_MBROADWELL) += -Ctarget-cpu=broadwell + rustflags-$(CONFIG_MSKYLAKE) += -Ctarget-cpu=skylake + rustflags-$(CONFIG_MSKYLAKEX) += -Ctarget-cpu=skylake-avx512 + rustflags-$(CONFIG_MCANNONLAKE) += -Ctarget-cpu=cannonlake + rustflags-$(CONFIG_MICELAKE_CLIENT) += -Ctarget-cpu=icelake-client + rustflags-$(CONFIG_MICELAKE_SERVER) += -Ctarget-cpu=icelake-server + rustflags-$(CONFIG_MCASCADELAKE) += -Ctarget-cpu=cascadelake + rustflags-$(CONFIG_MCOOPERLAKE) += -Ctarget-cpu=cooperlake + rustflags-$(CONFIG_MTIGERLAKE) += -Ctarget-cpu=tigerlake + rustflags-$(CONFIG_MSAPPHIRERAPIDS) += -Ctarget-cpu=sapphirerapids + rustflags-$(CONFIG_MROCKETLAKE) += -Ctarget-cpu=rocketlake + rustflags-$(CONFIG_MALDERLAKE) += -Ctarget-cpu=alderlake + rustflags-$(CONFIG_MRAPTORLAKE) += -Ctarget-cpu=raptorlake + rustflags-$(CONFIG_MMETEORLAKE) += -Ctarget-cpu=meteorlake + rustflags-$(CONFIG_MEMERALDRAPIDS) += -Ctarget-cpu=emeraldrapids KBUILD_RUSTFLAGS += $(rustflags-y) KBUILD_CFLAGS += -mno-red-zone diff --git a/arch/x86/include/asm/pci.h b/arch/x86/include/asm/pci.h index b3ab80a03365..5e883b397ff3 100644 --- a/arch/x86/include/asm/pci.h +++ b/arch/x86/include/asm/pci.h @@ -26,6 +26,7 @@ struct pci_sysdata { #if IS_ENABLED(CONFIG_VMD) struct pci_dev *vmd_dev; /* VMD Device if in Intel VMD domain */ #endif + struct pci_dev *nvme_remap_dev; /* AHCI Device if NVME remapped bus */ }; extern int pci_routeirq; @@ -69,6 +70,11 @@ static inline bool is_vmd(struct pci_bus *bus) #define is_vmd(bus) false #endif /* CONFIG_VMD */ +static inline bool is_nvme_remap(struct pci_bus *bus) +{ + return to_pci_sysdata(bus)->nvme_remap_dev != NULL; +} + /* Can be used to override the logic in pci_scan_bus for skipping already-configured bus numbers - to be used for buggy BIOSes or architectures with incomplete PCI setup by the loader */ diff --git a/arch/x86/include/asm/vermagic.h b/arch/x86/include/asm/vermagic.h index 75884d2cdec3..2fdae271f47f 100644 --- a/arch/x86/include/asm/vermagic.h +++ b/arch/x86/include/asm/vermagic.h @@ -17,6 +17,56 @@ #define MODULE_PROC_FAMILY "586MMX " #elif defined CONFIG_MCORE2 #define MODULE_PROC_FAMILY "CORE2 " +#elif defined CONFIG_MNATIVE_INTEL +#define MODULE_PROC_FAMILY "NATIVE_INTEL " +#elif defined CONFIG_MNATIVE_AMD +#define MODULE_PROC_FAMILY "NATIVE_AMD " +#elif defined CONFIG_MNEHALEM +#define MODULE_PROC_FAMILY "NEHALEM " +#elif defined CONFIG_MWESTMERE +#define MODULE_PROC_FAMILY "WESTMERE " +#elif defined CONFIG_MSILVERMONT +#define MODULE_PROC_FAMILY "SILVERMONT " +#elif defined CONFIG_MGOLDMONT +#define MODULE_PROC_FAMILY "GOLDMONT " +#elif defined CONFIG_MGOLDMONTPLUS +#define MODULE_PROC_FAMILY "GOLDMONTPLUS " +#elif defined CONFIG_MSANDYBRIDGE +#define MODULE_PROC_FAMILY "SANDYBRIDGE " +#elif defined CONFIG_MIVYBRIDGE +#define MODULE_PROC_FAMILY "IVYBRIDGE " +#elif defined CONFIG_MHASWELL +#define MODULE_PROC_FAMILY "HASWELL " +#elif defined CONFIG_MBROADWELL +#define MODULE_PROC_FAMILY "BROADWELL " +#elif defined CONFIG_MSKYLAKE +#define MODULE_PROC_FAMILY "SKYLAKE " +#elif defined CONFIG_MSKYLAKEX +#define MODULE_PROC_FAMILY "SKYLAKEX " +#elif defined CONFIG_MCANNONLAKE +#define MODULE_PROC_FAMILY "CANNONLAKE " +#elif defined CONFIG_MICELAKE_CLIENT +#define MODULE_PROC_FAMILY "ICELAKE_CLIENT " +#elif defined CONFIG_MICELAKE_SERVER +#define MODULE_PROC_FAMILY "ICELAKE_SERVER " +#elif defined CONFIG_MCASCADELAKE +#define MODULE_PROC_FAMILY "CASCADELAKE " +#elif defined CONFIG_MCOOPERLAKE +#define MODULE_PROC_FAMILY "COOPERLAKE " +#elif defined CONFIG_MTIGERLAKE +#define MODULE_PROC_FAMILY "TIGERLAKE " +#elif defined CONFIG_MSAPPHIRERAPIDS +#define MODULE_PROC_FAMILY "SAPPHIRERAPIDS " +#elif defined CONFIG_ROCKETLAKE +#define MODULE_PROC_FAMILY "ROCKETLAKE " +#elif defined CONFIG_MALDERLAKE +#define MODULE_PROC_FAMILY "ALDERLAKE " +#elif defined CONFIG_MRAPTORLAKE +#define MODULE_PROC_FAMILY "RAPTORLAKE " +#elif defined CONFIG_MMETEORLAKE +#define MODULE_PROC_FAMILY "METEORLAKE " +#elif defined CONFIG_MEMERALDRAPIDS +#define MODULE_PROC_FAMILY "EMERALDRAPIDS " #elif defined CONFIG_MATOM #define MODULE_PROC_FAMILY "ATOM " #elif defined CONFIG_M686 @@ -35,6 +85,28 @@ #define MODULE_PROC_FAMILY "K7 " #elif defined CONFIG_MK8 #define MODULE_PROC_FAMILY "K8 " +#elif defined CONFIG_MK8SSE3 +#define MODULE_PROC_FAMILY "K8SSE3 " +#elif defined CONFIG_MK10 +#define MODULE_PROC_FAMILY "K10 " +#elif defined CONFIG_MBARCELONA +#define MODULE_PROC_FAMILY "BARCELONA " +#elif defined CONFIG_MBOBCAT +#define MODULE_PROC_FAMILY "BOBCAT " +#elif defined CONFIG_MBULLDOZER +#define MODULE_PROC_FAMILY "BULLDOZER " +#elif defined CONFIG_MPILEDRIVER +#define MODULE_PROC_FAMILY "PILEDRIVER " +#elif defined CONFIG_MSTEAMROLLER +#define MODULE_PROC_FAMILY "STEAMROLLER " +#elif defined CONFIG_MJAGUAR +#define MODULE_PROC_FAMILY "JAGUAR " +#elif defined CONFIG_MEXCAVATOR +#define MODULE_PROC_FAMILY "EXCAVATOR " +#elif defined CONFIG_MZEN +#define MODULE_PROC_FAMILY "ZEN " +#elif defined CONFIG_MZEN2 +#define MODULE_PROC_FAMILY "ZEN2 " #elif defined CONFIG_MELAN #define MODULE_PROC_FAMILY "ELAN " #elif defined CONFIG_MCRUSOE diff --git a/arch/x86/pci/common.c b/arch/x86/pci/common.c index ddb798603201..7c20387d8202 100644 --- a/arch/x86/pci/common.c +++ b/arch/x86/pci/common.c @@ -723,12 +723,15 @@ int pci_ext_cfg_avail(void) return 0; } -#if IS_ENABLED(CONFIG_VMD) struct pci_dev *pci_real_dma_dev(struct pci_dev *dev) { +#if IS_ENABLED(CONFIG_VMD) if (is_vmd(dev->bus)) return to_pci_sysdata(dev->bus)->vmd_dev; +#endif + + if (is_nvme_remap(dev->bus)) + return to_pci_sysdata(dev->bus)->nvme_remap_dev; return dev; } -#endif diff --git a/block/Kconfig.iosched b/block/Kconfig.iosched index 27f11320b8d1..79fd5da5dd16 100644 --- a/block/Kconfig.iosched +++ b/block/Kconfig.iosched @@ -16,6 +16,15 @@ config MQ_IOSCHED_KYBER synchronous writes, it will self-tune queue depths to achieve that goal. +config MQ_IOSCHED_ADIOS + tristate "Adaptive Deadline I/O scheduler" + default m + help + ADIOS is a multi-queue I/O scheduler for the Linux kernel, based on + mq-deadline and Kyber, with learning-based adaptive latency control. + It aims to provide low latency for synchronous requests while + maintaining high throughput for asynchronous requests and bulk I/O. + config IOSCHED_BFQ tristate "BFQ I/O scheduler" select BLK_ICQ diff --git a/block/Makefile b/block/Makefile index 33748123710b..a67f651a7846 100644 --- a/block/Makefile +++ b/block/Makefile @@ -23,6 +23,7 @@ obj-$(CONFIG_BLK_CGROUP_IOLATENCY) += blk-iolatency.o obj-$(CONFIG_BLK_CGROUP_IOCOST) += blk-iocost.o obj-$(CONFIG_MQ_IOSCHED_DEADLINE) += mq-deadline.o obj-$(CONFIG_MQ_IOSCHED_KYBER) += kyber-iosched.o +obj-$(CONFIG_MQ_IOSCHED_ADIOS) += adios.o bfq-y := bfq-iosched.o bfq-wf2q.o bfq-cgroup.o obj-$(CONFIG_IOSCHED_BFQ) += bfq.o @@ -36,3 +37,10 @@ obj-$(CONFIG_BLK_INLINE_ENCRYPTION) += blk-crypto.o blk-crypto-profile.o \ blk-crypto-sysfs.o obj-$(CONFIG_BLK_INLINE_ENCRYPTION_FALLBACK) += blk-crypto-fallback.o obj-$(CONFIG_BLOCK_HOLDER_DEPRECATED) += holder.o + +all: + make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules + +clean: + make -C /lib/modules/$(shell uname -r)/build M=$(PWD) clean + diff --git a/block/adios.c b/block/adios.c new file mode 100644 index 000000000000..a35fcf6e4569 --- /dev/null +++ b/block/adios.c @@ -0,0 +1,1342 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * The Adaptive Deadline I/O Scheduler (ADIOS) + * Based on mq-deadline and Kyber, + * with learning-based adaptive latency control + * + * Copyright (C) 2025 Masahito Suzuki + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "elevator.h" +#include "blk.h" +#include "blk-mq.h" +#include "blk-mq-sched.h" + +#define ADIOS_VERSION "1.5.4" + +// Define operation types supported by ADIOS +enum adios_op_type { + ADIOS_READ = 0, + ADIOS_WRITE = 1, + ADIOS_DISCARD = 2, + ADIOS_OTHER = 3, + ADIOS_OPTYPES = 4, +}; + +// Global variable to control the latency +static u64 default_global_latency_window = 16000000ULL; +// Ratio below which batch queues should be refilled +static u8 default_bq_refill_below_ratio = 15; + +// Dynamic thresholds for shrinkage +static u32 default_lm_shrink_at_kreqs = 10000; +static u32 default_lm_shrink_at_gbytes = 100; +static u32 default_lm_shrink_resist = 2; + +// Latency targets for each operation type +static u64 default_latency_target[ADIOS_OPTYPES] = { + [ADIOS_READ] = 1ULL * NSEC_PER_MSEC, + [ADIOS_WRITE] = 2000ULL * NSEC_PER_MSEC, + [ADIOS_DISCARD] = 8000ULL * NSEC_PER_MSEC, + [ADIOS_OTHER] = 0ULL * NSEC_PER_MSEC, +}; + +// Maximum batch size limits for each operation type +static u32 default_batch_limit[ADIOS_OPTYPES] = { + [ADIOS_READ] = 24, + [ADIOS_WRITE] = 48, + [ADIOS_DISCARD] = 1, + [ADIOS_OTHER] = 1, +}; + +static u32 default_dl_prio[2] = {7, 0}; + +// Thresholds for latency model control +#define LM_BLOCK_SIZE_THRESHOLD 4096 +#define LM_SAMPLES_THRESHOLD 1024 +#define LM_INTERVAL_THRESHOLD 1500 +#define LM_OUTLIER_PERCENTILE 99 +#define LM_LAT_BUCKET_COUNT 64 + +// Structure to hold latency bucket data for small requests +struct latency_bucket_small { + u64 sum_latency; + u32 count; +}; + +// Structure to hold latency bucket data for large requests +struct latency_bucket_large { + u64 sum_latency; + u64 sum_block_size; + u32 count; +}; + +// Structure to hold the latency model context data +struct latency_model { + spinlock_t lock; + u64 base; + u64 slope; + u64 small_sum_delay; + u64 small_count; + u64 large_sum_delay; + u64 large_sum_bsize; + u64 last_update_jiffies; + + spinlock_t buckets_lock; + struct latency_bucket_small small_bucket[LM_LAT_BUCKET_COUNT]; + struct latency_bucket_large large_bucket[LM_LAT_BUCKET_COUNT]; + + u32 lm_shrink_at_kreqs; + u32 lm_shrink_at_gbytes; + u8 lm_shrink_resist; +}; + +#define ADIOS_BQ_PAGES 2 + +// Adios scheduler data +struct adios_data { + spinlock_t pq_lock; + struct list_head prio_queue; + + struct rb_root_cached dl_tree[2]; + spinlock_t lock; + u8 dl_queued; + s64 dl_bias; + s32 dl_prio[2]; + + u64 global_latency_window; + u64 latency_target[ADIOS_OPTYPES]; + u32 batch_limit[ADIOS_OPTYPES]; + u32 batch_actual_max_size[ADIOS_OPTYPES]; + u32 batch_actual_max_total; + u32 async_depth; + u8 bq_refill_below_ratio; + + u8 bq_page; + bool more_bq_ready; + struct list_head batch_queue[ADIOS_BQ_PAGES][ADIOS_OPTYPES]; + u32 batch_count[ADIOS_BQ_PAGES][ADIOS_OPTYPES]; + spinlock_t bq_lock; + + struct latency_model latency_model[ADIOS_OPTYPES]; + struct timer_list update_timer; + + atomic64_t total_pred_lat; + + struct kmem_cache *rq_data_pool; + struct kmem_cache *dl_group_pool; +}; + +// List of requests with the same deadline in the deadline-sorted tree +struct dl_group { + struct rb_node node; + struct list_head rqs; + u64 deadline; +} __attribute__((aligned(64))); + +// Structure to hold scheduler-specific data for each request +struct adios_rq_data { + struct list_head *dl_group; + struct list_head dl_node; + + struct request *rq; + u64 deadline; + u64 pred_lat; + u32 block_size; +} __attribute__((aligned(64))); + +static const int adios_prio_to_weight[40] = { + /* -20 */ 88761, 71755, 56483, 46273, 36291, + /* -15 */ 29154, 23254, 18705, 14949, 11916, + /* -10 */ 9548, 7620, 6100, 4904, 3906, + /* -5 */ 3121, 2501, 1991, 1586, 1277, + /* 0 */ 1024, 820, 655, 526, 423, + /* 5 */ 335, 272, 215, 172, 137, + /* 10 */ 110, 87, 70, 56, 45, + /* 15 */ 36, 29, 23, 18, 15, +}; + +// Count the number of entries in small buckets +static u32 lm_count_small_entries(struct latency_model *model) { + u32 total_count = 0; + for (u8 i = 0; i < LM_LAT_BUCKET_COUNT; i++) + total_count += model->small_bucket[i].count; + return total_count; +} + +// Update the small buckets in the latency model +static bool lm_update_small_buckets(struct latency_model *model, + u32 total_count, bool count_all) { + u64 sum_latency = 0; + u32 sum_count = 0; + u32 cumulative_count = 0, threshold_count = 0; + u8 outlier_threshold_bucket = 0; + u8 outlier_percentile = LM_OUTLIER_PERCENTILE; + u8 reduction; + + if (count_all) + outlier_percentile = 100; + + // Calculate the threshold count for outlier detection + threshold_count = (total_count * outlier_percentile) / 100; + + // Identify the bucket that corresponds to the outlier threshold + for (u8 i = 0; i < LM_LAT_BUCKET_COUNT; i++) { + cumulative_count += model->small_bucket[i].count; + if (cumulative_count >= threshold_count) { + outlier_threshold_bucket = i; + break; + } + } + + // Calculate the average latency, excluding outliers + for (u8 i = 0; i <= outlier_threshold_bucket; i++) { + struct latency_bucket_small *bucket = &model->small_bucket[i]; + if (i < outlier_threshold_bucket) { + sum_latency += bucket->sum_latency; + sum_count += bucket->count; + } else { + // The threshold bucket's contribution is proportional + u64 remaining_count = + threshold_count - (cumulative_count - bucket->count); + if (bucket->count > 0) { + sum_latency += + (bucket->sum_latency * remaining_count) / bucket->count; + sum_count += remaining_count; + } + } + } + + // Shrink the model if it reaches at the readjustment threshold + if (model->small_count >= 1000ULL * model->lm_shrink_at_kreqs) { + reduction = model->lm_shrink_resist; + if (model->small_count >> reduction) { + model->small_sum_delay -= model->small_sum_delay >> reduction; + model->small_count -= model->small_count >> reduction; + } + } + + // Accumulate the average latency into the statistics + model->small_sum_delay += sum_latency; + model->small_count += sum_count; + + // Reset small bucket information + memset(model->small_bucket, 0, + sizeof(model->small_bucket[0]) * LM_LAT_BUCKET_COUNT); + + return true; +} + +// Count the number of entries in large buckets +static u32 lm_count_large_entries(struct latency_model *model) { + u32 total_count = 0; + for (u8 i = 0; i < LM_LAT_BUCKET_COUNT; i++) + total_count += model->large_bucket[i].count; + return total_count; +} + +// Update the large buckets in the latency model +static bool lm_update_large_buckets( + struct latency_model *model, + u32 total_count, bool count_all) { + s64 sum_latency = 0; + u64 sum_block_size = 0, intercept; + u32 cumulative_count = 0, threshold_count = 0; + u8 outlier_threshold_bucket = 0; + u8 outlier_percentile = LM_OUTLIER_PERCENTILE; + u8 reduction; + + if (count_all) + outlier_percentile = 100; + + // Calculate the threshold count for outlier detection + threshold_count = (total_count * outlier_percentile) / 100; + + // Identify the bucket that corresponds to the outlier threshold + for (u8 i = 0; i < LM_LAT_BUCKET_COUNT; i++) { + cumulative_count += model->large_bucket[i].count; + if (cumulative_count >= threshold_count) { + outlier_threshold_bucket = i; + break; + } + } + + // Calculate the average latency and block size, excluding outliers + for (u8 i = 0; i <= outlier_threshold_bucket; i++) { + struct latency_bucket_large *bucket = &model->large_bucket[i]; + if (i < outlier_threshold_bucket) { + sum_latency += bucket->sum_latency; + sum_block_size += bucket->sum_block_size; + } else { + // The threshold bucket's contribution is proportional + u64 remaining_count = + threshold_count - (cumulative_count - bucket->count); + if (bucket->count > 0) { + sum_latency += + (bucket->sum_latency * remaining_count) / bucket->count; + sum_block_size += + (bucket->sum_block_size * remaining_count) / bucket->count; + } + } + } + + // Shrink the model if it reaches at the readjustment threshold + if (model->large_sum_bsize >= 0x40000000ULL * model->lm_shrink_at_gbytes) { + reduction = model->lm_shrink_resist; + if (model->large_sum_bsize >> reduction) { + model->large_sum_delay -= model->large_sum_delay >> reduction; + model->large_sum_bsize -= model->large_sum_bsize >> reduction; + } + } + + // Accumulate the average delay into the statistics + intercept = model->base * threshold_count; + if (sum_latency > intercept) + sum_latency -= intercept; + + model->large_sum_delay += sum_latency; + model->large_sum_bsize += sum_block_size; + + // Reset large bucket information + memset(model->large_bucket, 0, + sizeof(model->large_bucket[0]) * LM_LAT_BUCKET_COUNT); + + return true; +} + +// Update the latency model parameters and statistics +static void latency_model_update(struct latency_model *model) { + unsigned long flags; + u64 now; + u32 small_count, large_count; + bool time_elapsed; + bool small_processed = false, large_processed = false; + + guard(spinlock_irqsave)(&model->lock); + + spin_lock_irqsave(&model->buckets_lock, flags); + + // Whether enough time has elapsed since the last update + now = jiffies; + time_elapsed = unlikely(!model->base) || model->last_update_jiffies + + msecs_to_jiffies(LM_INTERVAL_THRESHOLD) <= now; + + // Count the number of entries in buckets + small_count = lm_count_small_entries(model); + large_count = lm_count_large_entries(model); + + // Update small buckets + if (small_count && (time_elapsed || + LM_SAMPLES_THRESHOLD <= small_count || !model->base)) + small_processed = lm_update_small_buckets( + model, small_count, !model->base); + // Update large buckets + if (large_count && (time_elapsed || + LM_SAMPLES_THRESHOLD <= large_count || !model->slope)) + large_processed = lm_update_large_buckets( + model, large_count, !model->slope); + + spin_unlock_irqrestore(&model->buckets_lock, flags); + + // Update the base parameter if small bucket was processed + if (small_processed && likely(model->small_count)) + model->base = div_u64(model->small_sum_delay, model->small_count); + + // Update the slope parameter if large bucket was processed + if (large_processed && likely(model->large_sum_bsize)) + model->slope = div_u64(model->large_sum_delay, + DIV_ROUND_UP_ULL(model->large_sum_bsize, 1024)); + + // Reset statistics and update last updated jiffies if time has elapsed + if (time_elapsed) + model->last_update_jiffies = now; +} + +// Determine the bucket index for a given measured and predicted latency +static u8 lm_input_bucket_index( + struct latency_model *model, u64 measured, u64 predicted) { + u8 bucket_index; + + if (measured < predicted * 2) + bucket_index = (measured * 20) / predicted; + else if (measured < predicted * 5) + bucket_index = (measured * 10) / predicted + 20; + else + bucket_index = (measured * 3) / predicted + 40; + + return bucket_index; +} + +// Input latency data into the latency model +static void latency_model_input(struct latency_model *model, + u32 block_size, u64 latency, u64 pred_lat) { + unsigned long flags; + u8 bucket_index; + + spin_lock_irqsave(&model->buckets_lock, flags); + + if (block_size <= LM_BLOCK_SIZE_THRESHOLD) { + // Handle small requests + bucket_index = lm_input_bucket_index(model, latency, model->base ?: 1); + + if (bucket_index >= LM_LAT_BUCKET_COUNT) + bucket_index = LM_LAT_BUCKET_COUNT - 1; + + model->small_bucket[bucket_index].count++; + model->small_bucket[bucket_index].sum_latency += latency; + + if (unlikely(!model->base)) { + spin_unlock_irqrestore(&model->buckets_lock, flags); + latency_model_update(model); + return; + } + } else { + // Handle large requests + if (!model->base || !pred_lat) { + spin_unlock_irqrestore(&model->buckets_lock, flags); + return; + } + + bucket_index = lm_input_bucket_index(model, latency, pred_lat); + + if (bucket_index >= LM_LAT_BUCKET_COUNT) + bucket_index = LM_LAT_BUCKET_COUNT - 1; + + model->large_bucket[bucket_index].count++; + model->large_bucket[bucket_index].sum_latency += latency; + model->large_bucket[bucket_index].sum_block_size += block_size; + } + + spin_unlock_irqrestore(&model->buckets_lock, flags); +} + +// Predict the latency for a given block size using the latency model +static u64 latency_model_predict(struct latency_model *model, u32 block_size) { + u64 result; + + guard(spinlock_irqsave)(&model->lock); + // Predict latency based on the model + result = model->base; + if (block_size > LM_BLOCK_SIZE_THRESHOLD) + result += model->slope * + DIV_ROUND_UP_ULL(block_size - LM_BLOCK_SIZE_THRESHOLD, 1024); + + return result; +} + +// Determine the type of operation based on request flags +static u8 adios_optype(struct request *rq) { + switch (rq->cmd_flags & REQ_OP_MASK) { + case REQ_OP_READ: + return ADIOS_READ; + case REQ_OP_WRITE: + return ADIOS_WRITE; + case REQ_OP_DISCARD: + return ADIOS_DISCARD; + default: + return ADIOS_OTHER; + } +} + +static inline u8 adios_optype_not_read(struct request *rq) { + return (rq->cmd_flags & REQ_OP_MASK) != REQ_OP_READ; +} + +// Helper function to retrieve adios_rq_data from a request +static inline struct adios_rq_data *get_rq_data(struct request *rq) { + return rq->elv.priv[0]; +} + +// Add a request to the deadline-sorted red-black tree +static void add_to_dl_tree( + struct adios_data *ad, bool dl_idx, struct request *rq) { + struct rb_root_cached *root = &ad->dl_tree[dl_idx]; + struct rb_node **link = &(root->rb_root.rb_node), *parent = NULL; + bool leftmost = true; + struct adios_rq_data *rd = get_rq_data(rq); + struct dl_group *dlg; + + rd->block_size = blk_rq_bytes(rq); + u8 optype = adios_optype(rq); + rd->pred_lat = + latency_model_predict(&ad->latency_model[optype], rd->block_size); + rd->deadline = + rq->start_time_ns + ad->latency_target[optype] + rd->pred_lat; + + while (*link) { + dlg = rb_entry(*link, struct dl_group, node); + s64 diff = rd->deadline - dlg->deadline; + + parent = *link; + if (diff < 0) { + link = &((*link)->rb_left); + } else if (diff > 0) { + link = &((*link)->rb_right); + leftmost = false; + } else { // diff == 0 + goto found; + } + } + + dlg = rb_entry_safe(parent, struct dl_group, node); + if (!dlg || dlg->deadline != rd->deadline) { + dlg = kmem_cache_zalloc(ad->dl_group_pool, GFP_ATOMIC); + if (!dlg) + return; + dlg->deadline = rd->deadline; + INIT_LIST_HEAD(&dlg->rqs); + rb_link_node(&dlg->node, parent, link); + rb_insert_color_cached(&dlg->node, root, leftmost); + } +found: + list_add_tail(&rd->dl_node, &dlg->rqs); + rd->dl_group = &dlg->rqs; + ad->dl_queued |= 1 << dl_idx; +} + +// Remove a request from the deadline-sorted red-black tree +static void del_from_dl_tree( + struct adios_data *ad, bool dl_idx, struct request *rq) { + struct rb_root_cached *root = &ad->dl_tree[dl_idx]; + struct adios_rq_data *rd = get_rq_data(rq); + struct dl_group *dlg = container_of(rd->dl_group, struct dl_group, rqs); + + list_del_init(&rd->dl_node); + if (list_empty(&dlg->rqs)) { + rb_erase_cached(&dlg->node, root); + kmem_cache_free(ad->dl_group_pool, dlg); + } + rd->dl_group = NULL; + + if (RB_EMPTY_ROOT(&ad->dl_tree[dl_idx].rb_root)) + ad->dl_queued &= ~(1 << dl_idx); +} + +// Remove a request from the scheduler +static void remove_request(struct adios_data *ad, struct request *rq) { + bool dl_idx = adios_optype_not_read(rq); + struct request_queue *q = rq->q; + struct adios_rq_data *rd = get_rq_data(rq); + + list_del_init(&rq->queuelist); + + // We might not be on the rbtree, if we are doing an insert merge + if (rd->dl_group) + del_from_dl_tree(ad, dl_idx, rq); + + elv_rqhash_del(q, rq); + if (q->last_merge == rq) + q->last_merge = NULL; +} + +// Convert a queue depth to the corresponding word depth for shallow allocation +static int to_word_depth(struct blk_mq_hw_ctx *hctx, unsigned int qdepth) { + struct sbitmap_queue *bt = &hctx->sched_tags->bitmap_tags; + const unsigned int nrr = hctx->queue->nr_requests; + + return ((qdepth << bt->sb.shift) + nrr - 1) / nrr; +} + +// Limit the depth of request allocation for asynchronous and write requests +static void adios_limit_depth(blk_opf_t opf, struct blk_mq_alloc_data *data) { + struct adios_data *ad = data->q->elevator->elevator_data; + + // Do not throttle synchronous reads + if (op_is_sync(opf) && !op_is_write(opf)) + return; + + data->shallow_depth = to_word_depth(data->hctx, ad->async_depth); +} + +// Update async_depth when the number of requests in the queue changes +static void adios_depth_updated(struct blk_mq_hw_ctx *hctx) { + struct request_queue *q = hctx->queue; + struct adios_data *ad = q->elevator->elevator_data; + struct blk_mq_tags *tags = hctx->sched_tags; + + ad->async_depth = q->nr_requests; + + sbitmap_queue_min_shallow_depth(&tags->bitmap_tags, 1); +} + +// Handle request merging after a merge operation +static void adios_request_merged(struct request_queue *q, struct request *req, + enum elv_merge type) { + bool dl_idx = adios_optype_not_read(req); + struct adios_data *ad = q->elevator->elevator_data; + + // if the merge was a front merge, we need to reposition request + if (type == ELEVATOR_FRONT_MERGE) { + del_from_dl_tree(ad, dl_idx, req); + add_to_dl_tree(ad, dl_idx, req); + } +} + +// Handle merging of requests after one has been merged into another +static void adios_merged_requests(struct request_queue *q, struct request *req, + struct request *next) { + struct adios_data *ad = q->elevator->elevator_data; + + lockdep_assert_held(&ad->lock); + + // kill knowledge of next, this one is a goner + remove_request(ad, next); +} + +// Try to merge a bio into an existing rq before associating it with an rq +static bool adios_bio_merge(struct request_queue *q, struct bio *bio, + unsigned int nr_segs) { + unsigned long flags; + struct adios_data *ad = q->elevator->elevator_data; + struct request *free = NULL; + bool ret; + + spin_lock_irqsave(&ad->lock, flags); + ret = blk_mq_sched_try_merge(q, bio, nr_segs, &free); + spin_unlock_irqrestore(&ad->lock, flags); + + if (free) + blk_mq_free_request(free); + + return ret; +} + +// Insert a request into the scheduler +static void insert_request(struct blk_mq_hw_ctx *hctx, struct request *rq, + blk_insert_t insert_flags, struct list_head *free) { + unsigned long flags; + bool dl_idx = adios_optype_not_read(rq); + struct request_queue *q = hctx->queue; + struct adios_data *ad = q->elevator->elevator_data; + + lockdep_assert_held(&ad->lock); + + if (insert_flags & BLK_MQ_INSERT_AT_HEAD) { + spin_lock_irqsave(&ad->pq_lock, flags); + list_add(&rq->queuelist, &ad->prio_queue); + spin_unlock_irqrestore(&ad->pq_lock, flags); + return; + } + + if (blk_mq_sched_try_insert_merge(q, rq, free)) + return; + + add_to_dl_tree(ad, dl_idx, rq); + + if (rq_mergeable(rq)) { + elv_rqhash_add(q, rq); + if (!q->last_merge) + q->last_merge = rq; + } +} + +// Insert multiple requests into the scheduler +static void adios_insert_requests(struct blk_mq_hw_ctx *hctx, + struct list_head *list, + blk_insert_t insert_flags) { + unsigned long flags; + struct request_queue *q = hctx->queue; + struct adios_data *ad = q->elevator->elevator_data; + LIST_HEAD(free); + + spin_lock_irqsave(&ad->lock, flags); + while (!list_empty(list)) { + struct request *rq; + + rq = list_first_entry(list, struct request, queuelist); + list_del_init(&rq->queuelist); + insert_request(hctx, rq, insert_flags, &free); + } + spin_unlock_irqrestore(&ad->lock, flags); + + blk_mq_free_requests(&free); +} + +// Prepare a request before it is inserted into the scheduler +static void adios_prepare_request(struct request *rq) { + struct adios_data *ad = rq->q->elevator->elevator_data; + struct adios_rq_data *rd; + + rq->elv.priv[0] = NULL; + + /* Allocate adios_rq_data from the memory pool */ + rd = kmem_cache_zalloc(ad->rq_data_pool, GFP_ATOMIC); + if (WARN(!rd, "adios_prepare_request: " + "Failed to allocate memory from rq_data_pool. rd is NULL\n")) + return; + + rd->rq = rq; + rq->elv.priv[0] = rd; +} + +static struct adios_rq_data *get_dl_first_rd(struct adios_data *ad, bool idx) { + struct rb_root_cached *root = &ad->dl_tree[idx]; + struct rb_node *first = rb_first_cached(root); + struct dl_group *dl_group = rb_entry(first, struct dl_group, node); + + return list_first_entry(&dl_group->rqs, struct adios_rq_data, dl_node); +} + +// Select the next request to dispatch from the deadline-sorted red-black tree +static struct request *next_request(struct adios_data *ad) { + struct adios_rq_data *rd; + bool dl_idx, bias_idx, reduce_bias; + + if (!ad->dl_queued) + return NULL; + + dl_idx = ad->dl_queued >> 1; + rd = get_dl_first_rd(ad, dl_idx); + + bias_idx = ad->dl_bias < 0; + reduce_bias = (bias_idx == dl_idx); + + if (ad->dl_queued == 0x3) { + struct adios_rq_data *trd[2]; + trd[0] = get_dl_first_rd(ad, 0); + trd[1] = rd; + + rd = trd[bias_idx]; + + reduce_bias = + (trd[bias_idx]->deadline > trd[((u8)bias_idx + 1) % 2]->deadline); + } + + if (reduce_bias) { + s64 sign = ((int)bias_idx << 1) - 1; + if (unlikely(!rd->pred_lat)) + ad->dl_bias = sign; + else { + ad->dl_bias += sign * (s64)((rd->pred_lat * + adios_prio_to_weight[ad->dl_prio[bias_idx] + 20]) >> 10); + } + } + + return rd->rq; +} + +// Reset the batch queue counts for a given page +static void reset_batch_counts(struct adios_data *ad, u8 page) { + memset(&ad->batch_count[page], 0, sizeof(ad->batch_count[page])); +} + +// Initialize all batch queues +static void init_batch_queues(struct adios_data *ad) { + for (u8 page = 0; page < ADIOS_BQ_PAGES; page++) { + reset_batch_counts(ad, page); + + for (u8 optype = 0; optype < ADIOS_OPTYPES; optype++) + INIT_LIST_HEAD(&ad->batch_queue[page][optype]); + } +} + +// Fill the batch queues with requests from the deadline-sorted red-black tree +static bool fill_batch_queues(struct adios_data *ad, u64 current_lat) { + unsigned long flags; + u32 count = 0; + u32 optype_count[ADIOS_OPTYPES] = {0}; + u8 page = (ad->bq_page + 1) % ADIOS_BQ_PAGES; + + reset_batch_counts(ad, page); + + spin_lock_irqsave(&ad->lock, flags); + while (true) { + struct request *rq = next_request(ad); + if (!rq) + break; + + struct adios_rq_data *rd = get_rq_data(rq); + u8 optype = adios_optype(rq); + current_lat += rd->pred_lat; + + // Check batch size and total predicted latency + if (count && (!ad->latency_model[optype].base || + ad->batch_count[page][optype] >= ad->batch_limit[optype] || + current_lat > ad->global_latency_window)) { + break; + } + + remove_request(ad, rq); + + // Add request to the corresponding batch queue + list_add_tail(&rq->queuelist, &ad->batch_queue[page][optype]); + ad->batch_count[page][optype]++; + atomic64_add(rd->pred_lat, &ad->total_pred_lat); + optype_count[optype]++; + count++; + } + spin_unlock_irqrestore(&ad->lock, flags); + + if (count) { + ad->more_bq_ready = true; + for (u8 optype = 0; optype < ADIOS_OPTYPES; optype++) { + if (ad->batch_actual_max_size[optype] < optype_count[optype]) + ad->batch_actual_max_size[optype] = optype_count[optype]; + } + if (ad->batch_actual_max_total < count) + ad->batch_actual_max_total = count; + } + return count; +} + +// Flip to the next batch queue page +static void flip_bq_page(struct adios_data *ad) { + ad->more_bq_ready = false; + ad->bq_page = (ad->bq_page + 1) % ADIOS_BQ_PAGES; +} + +// Dispatch a request from the batch queues +static struct request *dispatch_from_bq(struct adios_data *ad) { + struct request *rq = NULL; + u64 tpl; + + guard(spinlock_irqsave)(&ad->bq_lock); + + tpl = atomic64_read(&ad->total_pred_lat); + + if (!ad->more_bq_ready && (!tpl || + tpl < ad->global_latency_window * ad->bq_refill_below_ratio / 100)) + fill_batch_queues(ad, tpl); + +again: + // Check if there are any requests in the batch queues + for (u8 i = 0; i < ADIOS_OPTYPES; i++) { + if (!list_empty(&ad->batch_queue[ad->bq_page][i])) { + rq = list_first_entry(&ad->batch_queue[ad->bq_page][i], + struct request, queuelist); + list_del_init(&rq->queuelist); + return rq; + } + } + + // If there's more batch queue page available, flip to it and retry + if (ad->more_bq_ready) { + flip_bq_page(ad); + goto again; + } + + return NULL; +} + +// Dispatch a request from the priority queue +static struct request *dispatch_from_pq(struct adios_data *ad) { + struct request *rq = NULL; + + guard(spinlock_irqsave)(&ad->pq_lock); + + if (!list_empty(&ad->prio_queue)) { + rq = list_first_entry(&ad->prio_queue, struct request, queuelist); + list_del_init(&rq->queuelist); + } + return rq; +} + +// Dispatch a request to the hardware queue +static struct request *adios_dispatch_request(struct blk_mq_hw_ctx *hctx) { + struct adios_data *ad = hctx->queue->elevator->elevator_data; + struct request *rq; + + rq = dispatch_from_pq(ad); + if (rq) goto found; + rq = dispatch_from_bq(ad); + if (!rq) return NULL; +found: + rq->rq_flags |= RQF_STARTED; + return rq; +} + +// Timer callback function to periodically update latency models +static void update_timer_callback(struct timer_list *t) { + struct adios_data *ad = from_timer(ad, t, update_timer); + + for (u8 optype = 0; optype < ADIOS_OPTYPES; optype++) + latency_model_update(&ad->latency_model[optype]); +} + +// Handle the completion of a request +static void adios_completed_request(struct request *rq, u64 now) { + struct adios_data *ad = rq->q->elevator->elevator_data; + struct adios_rq_data *rd = get_rq_data(rq); + + atomic64_sub(rd->pred_lat, &ad->total_pred_lat); + + if (!rq->io_start_time_ns || !rd->block_size) + return; + u64 latency = now - rq->io_start_time_ns; + u8 optype = adios_optype(rq); + latency_model_input(&ad->latency_model[optype], + rd->block_size, latency, rd->pred_lat); + timer_reduce(&ad->update_timer, jiffies + msecs_to_jiffies(100)); +} + +// Clean up after a request is finished +static void adios_finish_request(struct request *rq) { + struct adios_data *ad = rq->q->elevator->elevator_data; + + if (rq->elv.priv[0]) { + // Free adios_rq_data back to the memory pool + kmem_cache_free(ad->rq_data_pool, get_rq_data(rq)); + rq->elv.priv[0] = NULL; + } +} + +static inline bool pq_has_work(struct adios_data *ad) { + guard(spinlock_irqsave)(&ad->pq_lock); + return !list_empty(&ad->prio_queue); +} + +static inline bool bq_has_work(struct adios_data *ad) { + guard(spinlock_irqsave)(&ad->bq_lock); + + for (u8 i = 0; i < ADIOS_OPTYPES; i++) + if (!list_empty(&ad->batch_queue[ad->bq_page][i])) + return true; + + return ad->more_bq_ready; +} + +static inline bool dl_tree_has_work(struct adios_data *ad) { + guard(spinlock_irqsave)(&ad->lock); + return ad->dl_queued; +} + +// Check if there are any requests available for dispatch +static bool adios_has_work(struct blk_mq_hw_ctx *hctx) { + struct adios_data *ad = hctx->queue->elevator->elevator_data; + + return pq_has_work(ad) || bq_has_work(ad) || dl_tree_has_work(ad); +} + +// Initialize the scheduler-specific data for a hardware queue +static int adios_init_hctx(struct blk_mq_hw_ctx *hctx, unsigned int hctx_idx) { + adios_depth_updated(hctx); + return 0; +} + +// Initialize the scheduler-specific data when initializing the request queue +static int adios_init_sched(struct request_queue *q, struct elevator_type *e) { + struct adios_data *ad; + struct elevator_queue *eq; + int ret = -ENOMEM; + + eq = elevator_alloc(q, e); + if (!eq) + return ret; + + ad = kzalloc_node(sizeof(*ad), GFP_KERNEL, q->node); + if (!ad) + goto put_eq; + + // Create a memory pool for adios_rq_data + ad->rq_data_pool = kmem_cache_create("rq_data_pool", + sizeof(struct adios_rq_data), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (!ad->rq_data_pool) { + pr_err("adios: Failed to create rq_data_pool\n"); + goto free_ad; + } + + /* Create a memory pool for dl_group */ + ad->dl_group_pool = kmem_cache_create("dl_group_pool", + sizeof(struct dl_group), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (!ad->dl_group_pool) { + pr_err("adios: Failed to create dl_group_pool\n"); + goto destroy_rq_data_pool; + } + + eq->elevator_data = ad; + + ad->global_latency_window = default_global_latency_window; + ad->bq_refill_below_ratio = default_bq_refill_below_ratio; + + INIT_LIST_HEAD(&ad->prio_queue); + for (u8 i = 0; i < 2; i++) + ad->dl_tree[i] = RB_ROOT_CACHED; + ad->dl_bias = 0; + ad->dl_queued = 0x0; + for (u8 i = 0; i < 2; i++) + ad->dl_prio[i] = default_dl_prio[i]; + + for (u8 i = 0; i < ADIOS_OPTYPES; i++) { + struct latency_model *model = &ad->latency_model[i]; + spin_lock_init(&model->lock); + spin_lock_init(&model->buckets_lock); + memset(model->small_bucket, 0, + sizeof(model->small_bucket[0]) * LM_LAT_BUCKET_COUNT); + memset(model->large_bucket, 0, + sizeof(model->large_bucket[0]) * LM_LAT_BUCKET_COUNT); + model->last_update_jiffies = jiffies; + model->lm_shrink_at_kreqs = default_lm_shrink_at_kreqs; + model->lm_shrink_at_gbytes = default_lm_shrink_at_gbytes; + model->lm_shrink_resist = default_lm_shrink_resist; + + ad->latency_target[i] = default_latency_target[i]; + ad->batch_limit[i] = default_batch_limit[i]; + } + timer_setup(&ad->update_timer, update_timer_callback, 0); + init_batch_queues(ad); + + spin_lock_init(&ad->lock); + spin_lock_init(&ad->pq_lock); + spin_lock_init(&ad->bq_lock); + + /* We dispatch from request queue wide instead of hw queue */ + blk_queue_flag_set(QUEUE_FLAG_SQ_SCHED, q); + + q->elevator = eq; + return 0; + +destroy_rq_data_pool: + kmem_cache_destroy(ad->rq_data_pool); +free_ad: + kfree(ad); +put_eq: + kobject_put(&eq->kobj); + return ret; +} + +// Clean up and free resources when exiting the scheduler +static void adios_exit_sched(struct elevator_queue *e) { + struct adios_data *ad = e->elevator_data; + + timer_shutdown_sync(&ad->update_timer); + + WARN_ON_ONCE(!list_empty(&ad->prio_queue)); + + if (ad->rq_data_pool) + kmem_cache_destroy(ad->rq_data_pool); + + if (ad->dl_group_pool) + kmem_cache_destroy(ad->dl_group_pool); + + kfree(ad); +} + +// Define sysfs attributes for read operation latency model +#define SYSFS_OPTYPE_DECL(name, optype) \ +static ssize_t adios_lat_model_##name##_show( \ + struct elevator_queue *e, char *page) { \ + struct adios_data *ad = e->elevator_data; \ + struct latency_model *model = &ad->latency_model[optype]; \ + ssize_t len = 0; \ + guard(spinlock_irqsave)(&model->lock); \ + len += sprintf(page, "base : %llu ns\n", model->base); \ + len += sprintf(page + len, "slope: %llu ns/KiB\n", model->slope);\ + return len; \ +} \ +static ssize_t adios_lat_target_##name##_store( \ + struct elevator_queue *e, const char *page, size_t count) { \ + struct adios_data *ad = e->elevator_data; \ + unsigned long nsec; \ + int ret; \ + ret = kstrtoul(page, 10, &nsec); \ + if (ret) \ + return ret; \ + ad->latency_model[optype].base = 0ULL; \ + ad->latency_target[optype] = nsec; \ + return count; \ +} \ +static ssize_t adios_lat_target_##name##_show( \ + struct elevator_queue *e, char *page) { \ + struct adios_data *ad = e->elevator_data; \ + return sprintf(page, "%llu\n", ad->latency_target[optype]); \ +} \ +static ssize_t adios_batch_limit_##name##_store( \ + struct elevator_queue *e, const char *page, size_t count) { \ + unsigned long max_batch; \ + int ret; \ + ret = kstrtoul(page, 10, &max_batch); \ + if (ret || max_batch == 0) \ + return -EINVAL; \ + struct adios_data *ad = e->elevator_data; \ + ad->batch_limit[optype] = max_batch; \ + return count; \ +} \ +static ssize_t adios_batch_limit_##name##_show( \ + struct elevator_queue *e, char *page) { \ + struct adios_data *ad = e->elevator_data; \ + return sprintf(page, "%u\n", ad->batch_limit[optype]); \ +} + +SYSFS_OPTYPE_DECL(read, ADIOS_READ); +SYSFS_OPTYPE_DECL(write, ADIOS_WRITE); +SYSFS_OPTYPE_DECL(discard, ADIOS_DISCARD); + +// Show the maximum batch size actually achieved for each operation type +static ssize_t adios_batch_actual_max_show( + struct elevator_queue *e, char *page) { + struct adios_data *ad = e->elevator_data; + u32 total_count, read_count, write_count, discard_count; + + total_count = ad->batch_actual_max_total; + read_count = ad->batch_actual_max_size[ADIOS_READ]; + write_count = ad->batch_actual_max_size[ADIOS_WRITE]; + discard_count = ad->batch_actual_max_size[ADIOS_DISCARD]; + + return sprintf(page, + "Total : %u\nDiscard: %u\nRead : %u\nWrite : %u\n", + total_count, discard_count, read_count, write_count); +} + +// Set the global latency window +static ssize_t adios_global_latency_window_store( + struct elevator_queue *e, const char *page, size_t count) { + struct adios_data *ad = e->elevator_data; + unsigned long nsec; + int ret; + + ret = kstrtoul(page, 10, &nsec); + if (ret) + return ret; + + ad->global_latency_window = nsec; + + return count; +} + +// Show the global latency window +static ssize_t adios_global_latency_window_show( + struct elevator_queue *e, char *page) { + struct adios_data *ad = e->elevator_data; + return sprintf(page, "%llu\n", ad->global_latency_window); +} + +// Show the bq_refill_below_ratio +static ssize_t adios_bq_refill_below_ratio_show( + struct elevator_queue *e, char *page) { + struct adios_data *ad = e->elevator_data; + return sprintf(page, "%d\n", ad->bq_refill_below_ratio); +} + +// Set the bq_refill_below_ratio +static ssize_t adios_bq_refill_below_ratio_store( + struct elevator_queue *e, const char *page, size_t count) { + struct adios_data *ad = e->elevator_data; + int ratio; + int ret; + + ret = kstrtoint(page, 10, &ratio); + if (ret || ratio < 0 || ratio > 100) + return -EINVAL; + + ad->bq_refill_below_ratio = ratio; + + return count; +} + +// Show the read priority +static ssize_t adios_read_priority_show( + struct elevator_queue *e, char *page) { + struct adios_data *ad = e->elevator_data; + return sprintf(page, "%d\n", ad->dl_prio[0]); +} + +// Set the read priority +static ssize_t adios_read_priority_store( + struct elevator_queue *e, const char *page, size_t count) { + struct adios_data *ad = e->elevator_data; + int prio; + int ret; + + ret = kstrtoint(page, 10, &prio); + if (ret || prio < -20 || prio > 19) + return -EINVAL; + + guard(spinlock_irqsave)(&ad->lock); + ad->dl_prio[0] = prio; + ad->dl_bias = 0; + + return count; +} + +// Reset batch queue statistics +static ssize_t adios_reset_bq_stats_store( + struct elevator_queue *e, const char *page, size_t count) { + struct adios_data *ad = e->elevator_data; + unsigned long val; + int ret; + + ret = kstrtoul(page, 10, &val); + if (ret || val != 1) + return -EINVAL; + + for (u8 i = 0; i < ADIOS_OPTYPES; i++) + ad->batch_actual_max_size[i] = 0; + + ad->batch_actual_max_total = 0; + + return count; +} + +// Reset the latency model parameters +static ssize_t adios_reset_lat_model_store( + struct elevator_queue *e, const char *page, size_t count) { + struct adios_data *ad = e->elevator_data; + unsigned long val; + int ret; + + ret = kstrtoul(page, 10, &val); + if (ret || val != 1) + return -EINVAL; + + for (u8 i = 0; i < ADIOS_OPTYPES; i++) { + struct latency_model *model = &ad->latency_model[i]; + unsigned long flags; + spin_lock_irqsave(&model->lock, flags); + model->base = 0ULL; + model->slope = 0ULL; + model->small_sum_delay = 0ULL; + model->small_count = 0ULL; + model->large_sum_delay = 0ULL; + model->large_sum_bsize = 0ULL; + spin_unlock_irqrestore(&model->lock, flags); + } + + return count; +} + +// Show the ADIOS version +static ssize_t adios_version_show(struct elevator_queue *e, char *page) { + return sprintf(page, "%s\n", ADIOS_VERSION); +} + +// Define sysfs attributes for dynamic thresholds +#define SHRINK_THRESHOLD_ATTR_RW(name, model_field, min_value, max_value) \ +static ssize_t adios_shrink_##name##_store( \ + struct elevator_queue *e, const char *page, size_t count) { \ + struct adios_data *ad = e->elevator_data; \ + unsigned long val; \ + int ret; \ + ret = kstrtoul(page, 10, &val); \ + if (ret || val < min_value || val > max_value) \ + return -EINVAL; \ + for (u8 i = 0; i < ADIOS_OPTYPES; i++) { \ + struct latency_model *model = &ad->latency_model[i]; \ + unsigned long flags; \ + spin_lock_irqsave(&model->lock, flags); \ + model->model_field = val; \ + spin_unlock_irqrestore(&model->lock, flags); \ + } \ + return count; \ +} \ +static ssize_t adios_shrink_##name##_show( \ + struct elevator_queue *e, char *page) { \ + struct adios_data *ad = e->elevator_data; \ + u32 val = 0; \ + for (u8 i = 0; i < ADIOS_OPTYPES; i++) { \ + struct latency_model *model = &ad->latency_model[i]; \ + unsigned long flags; \ + spin_lock_irqsave(&model->lock, flags); \ + val = model->model_field; \ + spin_unlock_irqrestore(&model->lock, flags); \ + } \ + return sprintf(page, "%u\n", val); \ +} + +SHRINK_THRESHOLD_ATTR_RW(at_kreqs, lm_shrink_at_kreqs, 1, 100000) +SHRINK_THRESHOLD_ATTR_RW(at_gbytes, lm_shrink_at_gbytes, 1, 1000) +SHRINK_THRESHOLD_ATTR_RW(resist, lm_shrink_resist, 1, 3) + +// Define sysfs attributes +#define AD_ATTR(name, show_func, store_func) \ + __ATTR(name, 0644, show_func, store_func) +#define AD_ATTR_RW(name) \ + __ATTR(name, 0644, adios_##name##_show, adios_##name##_store) +#define AD_ATTR_RO(name) \ + __ATTR(name, 0644, adios_##name##_show, NULL) +#define AD_ATTR_WO(name) \ + __ATTR(name, 0644, NULL, adios_##name##_store) + +// Define sysfs attributes for ADIOS scheduler +static struct elv_fs_entry adios_sched_attrs[] = { + AD_ATTR_RO(batch_actual_max), + AD_ATTR_RW(bq_refill_below_ratio), + AD_ATTR_RW(global_latency_window), + + AD_ATTR_RW(batch_limit_read), + AD_ATTR_RW(batch_limit_write), + AD_ATTR_RW(batch_limit_discard), + + AD_ATTR_RO(lat_model_read), + AD_ATTR_RO(lat_model_write), + AD_ATTR_RO(lat_model_discard), + + AD_ATTR_RW(lat_target_read), + AD_ATTR_RW(lat_target_write), + AD_ATTR_RW(lat_target_discard), + + AD_ATTR_RW(shrink_at_kreqs), + AD_ATTR_RW(shrink_at_gbytes), + AD_ATTR_RW(shrink_resist), + + AD_ATTR_RW(read_priority), + + AD_ATTR_WO(reset_bq_stats), + AD_ATTR_WO(reset_lat_model), + AD_ATTR(adios_version, adios_version_show, NULL), + + __ATTR_NULL +}; + +// Define the ADIOS scheduler type +static struct elevator_type mq_adios = { + .ops = { + .next_request = elv_rb_latter_request, + .former_request = elv_rb_former_request, + .limit_depth = adios_limit_depth, + .depth_updated = adios_depth_updated, + .request_merged = adios_request_merged, + .requests_merged = adios_merged_requests, + .bio_merge = adios_bio_merge, + .insert_requests = adios_insert_requests, + .prepare_request = adios_prepare_request, + .dispatch_request = adios_dispatch_request, + .completed_request = adios_completed_request, + .finish_request = adios_finish_request, + .has_work = adios_has_work, + .init_hctx = adios_init_hctx, + .init_sched = adios_init_sched, + .exit_sched = adios_exit_sched, + }, +#ifdef CONFIG_BLK_DEBUG_FS +#endif + .elevator_attrs = adios_sched_attrs, + .elevator_name = "adios", + .elevator_owner = THIS_MODULE, +}; +MODULE_ALIAS("mq-adios-iosched"); + +#define ADIOS_PROGNAME "Adaptive Deadline I/O Scheduler" +#define ADIOS_AUTHOR "Masahito Suzuki" + +// Initialize the ADIOS scheduler module +static int __init adios_init(void) { + printk(KERN_INFO "%s %s by %s\n", + ADIOS_PROGNAME, ADIOS_VERSION, ADIOS_AUTHOR); + return elv_register(&mq_adios); +} + +// Exit the ADIOS scheduler module +static void __exit adios_exit(void) { + elv_unregister(&mq_adios); +} + +module_init(adios_init); +module_exit(adios_exit); + +MODULE_AUTHOR(ADIOS_AUTHOR); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION(ADIOS_PROGNAME); \ No newline at end of file diff --git a/block/elevator.c b/block/elevator.c index cd2ce4921601..bcd6d2482a8a 100644 --- a/block/elevator.c +++ b/block/elevator.c @@ -558,9 +558,17 @@ static struct elevator_type *elevator_get_default(struct request_queue *q) if (q->nr_hw_queues != 1 && !blk_mq_is_shared_tags(q->tag_set->flags)) +#if defined(CONFIG_CACHY) + return elevator_find_get("mq-deadline"); +#else return NULL; +#endif +#if defined(CONFIG_CACHY) && defined(CONFIG_IOSCHED_BFQ) + return elevator_find_get("bfq"); +#else return elevator_find_get("mq-deadline"); +#endif } /* diff --git a/drivers/Makefile b/drivers/Makefile index 45d1c3e630f7..4f5ab2429a7f 100644 --- a/drivers/Makefile +++ b/drivers/Makefile @@ -64,14 +64,8 @@ obj-y += char/ # iommu/ comes before gpu as gpu are using iommu controllers obj-y += iommu/ -# gpu/ comes after char for AGP vs DRM startup and after iommu -obj-y += gpu/ - obj-$(CONFIG_CONNECTOR) += connector/ -# i810fb depends on char/agp/ -obj-$(CONFIG_FB_I810) += video/fbdev/i810/ - obj-$(CONFIG_PARPORT) += parport/ obj-y += base/ block/ misc/ mfd/ nfc/ obj-$(CONFIG_LIBNVDIMM) += nvdimm/ @@ -83,6 +77,13 @@ obj-y += macintosh/ obj-y += scsi/ obj-y += nvme/ obj-$(CONFIG_ATA) += ata/ + +# gpu/ comes after char for AGP vs DRM startup and after iommu +obj-y += gpu/ + +# i810fb depends on char/agp/ +obj-$(CONFIG_FB_I810) += video/fbdev/i810/ + obj-$(CONFIG_TARGET_CORE) += target/ obj-$(CONFIG_MTD) += mtd/ obj-$(CONFIG_SPI) += spi/ diff --git a/drivers/ata/ahci.c b/drivers/ata/ahci.c index f3a6bfe098cd..e926c78380a2 100644 --- a/drivers/ata/ahci.c +++ b/drivers/ata/ahci.c @@ -1629,7 +1629,7 @@ static irqreturn_t ahci_thunderx_irq_handler(int irq, void *dev_instance) } #endif -static void ahci_remap_check(struct pci_dev *pdev, int bar, +static int ahci_remap_check(struct pci_dev *pdev, int bar, struct ahci_host_priv *hpriv) { int i; @@ -1642,7 +1642,7 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, pci_resource_len(pdev, bar) < SZ_512K || bar != AHCI_PCI_BAR_STANDARD || !(readl(hpriv->mmio + AHCI_VSCAP) & 1)) - return; + return 0; cap = readq(hpriv->mmio + AHCI_REMAP_CAP); for (i = 0; i < AHCI_MAX_REMAP; i++) { @@ -1657,18 +1657,11 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, } if (!hpriv->remapped_nvme) - return; - - dev_warn(&pdev->dev, "Found %u remapped NVMe devices.\n", - hpriv->remapped_nvme); - dev_warn(&pdev->dev, - "Switch your BIOS from RAID to AHCI mode to use them.\n"); + return 0; - /* - * Don't rely on the msi-x capability in the remap case, - * share the legacy interrupt across ahci and remapped devices. - */ - hpriv->flags |= AHCI_HFLAG_NO_MSI; + /* Abort probe, allowing intel-nvme-remap to step in when available */ + dev_info(&pdev->dev, "Device will be handled by intel-nvme-remap.\n"); + return -ENODEV; } static int ahci_get_irq_vector(struct ata_host *host, int port) @@ -1909,7 +1902,9 @@ static int ahci_init_one(struct pci_dev *pdev, const struct pci_device_id *ent) return -ENOMEM; /* detect remapped nvme devices */ - ahci_remap_check(pdev, ahci_pci_bar, hpriv); + rc = ahci_remap_check(pdev, ahci_pci_bar, hpriv); + if (rc) + return rc; sysfs_add_file_to_group(&pdev->dev.kobj, &dev_attr_remapped_nvme.attr, diff --git a/drivers/cpufreq/Kconfig.x86 b/drivers/cpufreq/Kconfig.x86 index 97c2d4f15d76..5a3af44d785a 100644 --- a/drivers/cpufreq/Kconfig.x86 +++ b/drivers/cpufreq/Kconfig.x86 @@ -9,7 +9,6 @@ config X86_INTEL_PSTATE select ACPI_PROCESSOR if ACPI select ACPI_CPPC_LIB if X86_64 && ACPI && SCHED_MC_PRIO select CPU_FREQ_GOV_PERFORMANCE - select CPU_FREQ_GOV_SCHEDUTIL if SMP help This driver provides a P state for Intel core processors. The driver implements an internal governor and will become @@ -39,7 +38,6 @@ config X86_AMD_PSTATE depends on X86 && ACPI select ACPI_PROCESSOR select ACPI_CPPC_LIB if X86_64 - select CPU_FREQ_GOV_SCHEDUTIL if SMP help This driver adds a CPUFreq driver which utilizes a fine grain processor performance frequency control range instead of legacy diff --git a/drivers/cpufreq/intel_pstate.c b/drivers/cpufreq/intel_pstate.c index 9c4cc01fd51a..06a3c53e116e 100644 --- a/drivers/cpufreq/intel_pstate.c +++ b/drivers/cpufreq/intel_pstate.c @@ -3827,6 +3827,8 @@ static int __init intel_pstate_setup(char *str) if (!strcmp(str, "disable")) no_load = 1; + else if (!strcmp(str, "enable")) + no_load = 0; else if (!strcmp(str, "active")) default_driver = &intel_pstate; else if (!strcmp(str, "passive")) diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu.h b/drivers/gpu/drm/amd/amdgpu/amdgpu.h index 69895fccb474..5884de0f23bb 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu.h @@ -160,6 +160,7 @@ struct amdgpu_watchdog_timer { */ extern int amdgpu_modeset; extern unsigned int amdgpu_vram_limit; +extern int amdgpu_ignore_min_pcap; extern int amdgpu_vis_vram_limit; extern int amdgpu_gart_size; extern int amdgpu_gtt_size; diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.c index 093141ad6ed0..e476e45b996a 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.c @@ -36,13 +36,6 @@ #include "atombios_encoders.h" #include "bif/bif_4_1_d.h" -static void amdgpu_atombios_lookup_i2c_gpio_quirks(struct amdgpu_device *adev, - ATOM_GPIO_I2C_ASSIGMENT *gpio, - u8 index) -{ - -} - static struct amdgpu_i2c_bus_rec amdgpu_atombios_get_bus_rec_for_i2c_gpio(ATOM_GPIO_I2C_ASSIGMENT *gpio) { struct amdgpu_i2c_bus_rec i2c; @@ -108,9 +101,6 @@ struct amdgpu_i2c_bus_rec amdgpu_atombios_lookup_i2c_gpio(struct amdgpu_device * gpio = &i2c_info->asGPIO_Info[0]; for (i = 0; i < num_indices; i++) { - - amdgpu_atombios_lookup_i2c_gpio_quirks(adev, gpio, i); - if (gpio->sucI2cId.ucAccess == id) { i2c = amdgpu_atombios_get_bus_rec_for_i2c_gpio(gpio); break; @@ -142,8 +132,6 @@ void amdgpu_atombios_i2c_init(struct amdgpu_device *adev) gpio = &i2c_info->asGPIO_Info[0]; for (i = 0; i < num_indices; i++) { - amdgpu_atombios_lookup_i2c_gpio_quirks(adev, gpio, i); - i2c = amdgpu_atombios_get_bus_rec_for_i2c_gpio(gpio); if (i2c.valid) { @@ -156,6 +144,38 @@ void amdgpu_atombios_i2c_init(struct amdgpu_device *adev) } } +void amdgpu_atombios_oem_i2c_init(struct amdgpu_device *adev, u8 i2c_id) +{ + struct atom_context *ctx = adev->mode_info.atom_context; + ATOM_GPIO_I2C_ASSIGMENT *gpio; + struct amdgpu_i2c_bus_rec i2c; + int index = GetIndexIntoMasterTable(DATA, GPIO_I2C_Info); + struct _ATOM_GPIO_I2C_INFO *i2c_info; + uint16_t data_offset, size; + int i, num_indices; + char stmp[32]; + + if (amdgpu_atom_parse_data_header(ctx, index, &size, NULL, NULL, &data_offset)) { + i2c_info = (struct _ATOM_GPIO_I2C_INFO *)(ctx->bios + data_offset); + + num_indices = (size - sizeof(ATOM_COMMON_TABLE_HEADER)) / + sizeof(ATOM_GPIO_I2C_ASSIGMENT); + + gpio = &i2c_info->asGPIO_Info[0]; + for (i = 0; i < num_indices; i++) { + i2c = amdgpu_atombios_get_bus_rec_for_i2c_gpio(gpio); + + if (i2c.valid && i2c.i2c_id == i2c_id) { + sprintf(stmp, "OEM 0x%x", i2c.i2c_id); + adev->i2c_bus[i] = amdgpu_i2c_create(adev_to_drm(adev), &i2c, stmp); + break; + } + gpio = (ATOM_GPIO_I2C_ASSIGMENT *) + ((u8 *)gpio + sizeof(ATOM_GPIO_I2C_ASSIGMENT)); + } + } +} + struct amdgpu_gpio_rec amdgpu_atombios_lookup_gpio(struct amdgpu_device *adev, u8 id) diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.h b/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.h index 0e16432d9a72..867bc5c5ce67 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_atombios.h @@ -136,6 +136,7 @@ amdgpu_atombios_lookup_gpio(struct amdgpu_device *adev, struct amdgpu_i2c_bus_rec amdgpu_atombios_lookup_i2c_gpio(struct amdgpu_device *adev, uint8_t id); void amdgpu_atombios_i2c_init(struct amdgpu_device *adev); +void amdgpu_atombios_oem_i2c_init(struct amdgpu_device *adev, u8 i2c_id); bool amdgpu_atombios_has_dce_engine_info(struct amdgpu_device *adev); diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_device.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_device.c index 9a8f6cb2b836..b888eaaaec23 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_device.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_device.c @@ -4472,8 +4472,7 @@ int amdgpu_device_init(struct amdgpu_device *adev, goto failed; } /* init i2c buses */ - if (!amdgpu_device_has_dc_support(adev)) - amdgpu_atombios_i2c_init(adev); + amdgpu_i2c_init(adev); } } @@ -4742,8 +4741,7 @@ void amdgpu_device_fini_sw(struct amdgpu_device *adev) amdgpu_reset_fini(adev); /* free i2c buses */ - if (!amdgpu_device_has_dc_support(adev)) - amdgpu_i2c_fini(adev); + amdgpu_i2c_fini(adev); if (amdgpu_emu_mode != 1) amdgpu_atombios_fini(adev); diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c index c0ddbe7d6f0b..73b1a742c5e4 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c @@ -139,6 +139,7 @@ enum AMDGPU_DEBUG_MASK { }; unsigned int amdgpu_vram_limit = UINT_MAX; +int amdgpu_ignore_min_pcap = 0; /* do not ignore by default */ int amdgpu_vis_vram_limit; int amdgpu_gart_size = -1; /* auto */ int amdgpu_gtt_size = -1; /* auto */ @@ -257,6 +258,15 @@ struct amdgpu_watchdog_timer amdgpu_watchdog_timer = { .period = 0x0, /* default to 0x0 (timeout disable) */ }; +/** + * DOC: ignore_min_pcap (int) + * Ignore the minimum power cap. + * Useful on graphics cards where the minimum power cap is very high. + * The default is 0 (Do not ignore). + */ +MODULE_PARM_DESC(ignore_min_pcap, "Ignore the minimum power cap"); +module_param_named(ignore_min_pcap, amdgpu_ignore_min_pcap, int, 0600); + /** * DOC: vramlimit (int) * Restrict the total amount of VRAM in MiB for testing. The default is 0 (Use full VRAM). diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.c index f0765ccde668..8179d0814db9 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.c @@ -225,6 +225,25 @@ void amdgpu_i2c_destroy(struct amdgpu_i2c_chan *i2c) kfree(i2c); } +void amdgpu_i2c_init(struct amdgpu_device *adev) +{ + if (!adev->is_atom_fw) { + if (!amdgpu_device_has_dc_support(adev)) { + amdgpu_atombios_i2c_init(adev); + } else { + switch (adev->asic_type) { + case CHIP_POLARIS10: + case CHIP_POLARIS11: + case CHIP_POLARIS12: + amdgpu_atombios_oem_i2c_init(adev, 0x97); + break; + default: + break; + } + } + } +} + /* remove all the buses */ void amdgpu_i2c_fini(struct amdgpu_device *adev) { diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.h b/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.h index 21e3d1dad0a1..1d3d3806e0dd 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_i2c.h @@ -28,6 +28,7 @@ struct amdgpu_i2c_chan *amdgpu_i2c_create(struct drm_device *dev, const struct amdgpu_i2c_bus_rec *rec, const char *name); void amdgpu_i2c_destroy(struct amdgpu_i2c_chan *i2c); +void amdgpu_i2c_init(struct amdgpu_device *adev); void amdgpu_i2c_fini(struct amdgpu_device *adev); struct amdgpu_i2c_chan * amdgpu_i2c_lookup(struct amdgpu_device *adev, diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_mode.h b/drivers/gpu/drm/amd/amdgpu/amdgpu_mode.h index 5e3faefc5510..6da4f946cac0 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_mode.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_mode.h @@ -609,6 +609,7 @@ struct amdgpu_i2c_adapter { struct i2c_adapter base; struct ddc_service *ddc_service; + bool oem; }; #define TO_DM_AUX(x) container_of((x), struct amdgpu_dm_dp_aux, aux) diff --git a/drivers/gpu/drm/amd/display/Kconfig b/drivers/gpu/drm/amd/display/Kconfig index abd3b6564373..46937e6fa78d 100644 --- a/drivers/gpu/drm/amd/display/Kconfig +++ b/drivers/gpu/drm/amd/display/Kconfig @@ -56,4 +56,10 @@ config DRM_AMD_SECURE_DISPLAY This option enables the calculation of crc of specific region via debugfs. Cooperate with specific DMCU FW. +config AMD_PRIVATE_COLOR + bool "Enable KMS color management by AMD for AMD" + default n + help + This option extends the KMS color management API with AMD driver-specific properties to enhance the color management support on AMD Steam Deck. + endmenu diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c index 39df45f652b3..6d99dc4013c2 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c @@ -179,6 +179,8 @@ static int amdgpu_dm_init(struct amdgpu_device *adev); static void amdgpu_dm_fini(struct amdgpu_device *adev); static bool is_freesync_video_mode(const struct drm_display_mode *mode, struct amdgpu_dm_connector *aconnector); static void reset_freesync_config_for_crtc(struct dm_crtc_state *new_crtc_state); +static struct amdgpu_i2c_adapter * +create_i2c(struct ddc_service *ddc_service, bool oem); static enum drm_mode_subconnector get_subconnector_type(struct dc_link *link) { @@ -2952,6 +2954,33 @@ static int amdgpu_dm_smu_write_watermarks_table(struct amdgpu_device *adev) return 0; } +static int dm_oem_i2c_hw_init(struct amdgpu_device *adev) +{ + struct amdgpu_display_manager *dm = &adev->dm; + struct amdgpu_i2c_adapter *oem_i2c; + struct ddc_service *oem_ddc_service; + int r; + + oem_ddc_service = dc_get_oem_i2c_device(adev->dm.dc); + if (oem_ddc_service) { + oem_i2c = create_i2c(oem_ddc_service, true); + if (!oem_i2c) { + dev_info(adev->dev, "Failed to create oem i2c adapter data\n"); + return -ENOMEM; + } + + r = i2c_add_adapter(&oem_i2c->base); + if (r) { + dev_info(adev->dev, "Failed to register oem i2c\n"); + kfree(oem_i2c); + return r; + } + dm->oem_i2c = oem_i2c; + } + + return 0; +} + /** * dm_hw_init() - Initialize DC device * @ip_block: Pointer to the amdgpu_ip_block for this hw instance. @@ -2983,6 +3012,10 @@ static int dm_hw_init(struct amdgpu_ip_block *ip_block) return r; amdgpu_dm_hpd_init(adev); + r = dm_oem_i2c_hw_init(adev); + if (r) + dev_info(adev->dev, "Failed to add OEM i2c bus\n"); + return 0; } @@ -2998,6 +3031,8 @@ static int dm_hw_fini(struct amdgpu_ip_block *ip_block) { struct amdgpu_device *adev = ip_block->adev; + kfree(adev->dm.oem_i2c); + amdgpu_dm_hpd_fini(adev); amdgpu_dm_irq_fini(adev); @@ -4652,7 +4687,7 @@ static int amdgpu_dm_mode_config_init(struct amdgpu_device *adev) return r; } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR if (amdgpu_dm_create_color_properties(adev)) { dc_state_release(state->context); kfree(state); @@ -8365,7 +8400,7 @@ static int amdgpu_dm_i2c_xfer(struct i2c_adapter *i2c_adap, int i; int result = -EIO; - if (!ddc_service->ddc_pin || !ddc_service->ddc_pin->hw_info.hw_supported) + if (!ddc_service->ddc_pin) return result; cmd.payloads = kcalloc(num, sizeof(struct i2c_payload), GFP_KERNEL); @@ -8384,11 +8419,18 @@ static int amdgpu_dm_i2c_xfer(struct i2c_adapter *i2c_adap, cmd.payloads[i].data = msgs[i].buf; } - if (dc_submit_i2c( - ddc_service->ctx->dc, - ddc_service->link->link_index, - &cmd)) - result = num; + if (i2c->oem) { + if (dc_submit_i2c_oem( + ddc_service->ctx->dc, + &cmd)) + result = num; + } else { + if (dc_submit_i2c( + ddc_service->ctx->dc, + ddc_service->link->link_index, + &cmd)) + result = num; + } kfree(cmd.payloads); return result; @@ -8405,9 +8447,7 @@ static const struct i2c_algorithm amdgpu_dm_i2c_algo = { }; static struct amdgpu_i2c_adapter * -create_i2c(struct ddc_service *ddc_service, - int link_index, - int *res) +create_i2c(struct ddc_service *ddc_service, bool oem) { struct amdgpu_device *adev = ddc_service->ctx->driver_context; struct amdgpu_i2c_adapter *i2c; @@ -8418,9 +8458,14 @@ create_i2c(struct ddc_service *ddc_service, i2c->base.owner = THIS_MODULE; i2c->base.dev.parent = &adev->pdev->dev; i2c->base.algo = &amdgpu_dm_i2c_algo; - snprintf(i2c->base.name, sizeof(i2c->base.name), "AMDGPU DM i2c hw bus %d", link_index); + if (oem) + snprintf(i2c->base.name, sizeof(i2c->base.name), "AMDGPU DM i2c OEM bus"); + else + snprintf(i2c->base.name, sizeof(i2c->base.name), "AMDGPU DM i2c hw bus %d", + ddc_service->link->link_index); i2c_set_adapdata(&i2c->base, i2c); i2c->ddc_service = ddc_service; + i2c->oem = oem; return i2c; } @@ -8466,7 +8511,7 @@ static int amdgpu_dm_connector_init(struct amdgpu_display_manager *dm, link->priv = aconnector; - i2c = create_i2c(link->ddc, link->link_index, &res); + i2c = create_i2c(link->ddc, false); if (!i2c) { DRM_ERROR("Failed to create i2c adapter data\n"); return -ENOMEM; diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.h b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.h index d2703ca7dff3..ef60e80de19c 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.h +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.h @@ -606,6 +606,13 @@ struct amdgpu_display_manager { * Bounding box data read from dmub during early initialization for DCN4+ */ struct dml2_soc_bb *bb_from_dmub; + + /** + * @oem_i2c: + * + * OEM i2c bus + */ + struct amdgpu_i2c_adapter *oem_i2c; }; enum dsc_clock_force_state { diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c index ebabfe3a512f..4d3ebcaacca1 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c @@ -97,7 +97,7 @@ static inline struct fixed31_32 amdgpu_dm_fixpt_from_s3132(__u64 x) return val; } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR /* Pre-defined Transfer Functions (TF) * * AMD driver supports pre-defined mathematical functions for transferring diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c index 36a830a7440f..a8fc8bd52d51 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c @@ -470,7 +470,7 @@ static int amdgpu_dm_crtc_late_register(struct drm_crtc *crtc) } #endif -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR /** * dm_crtc_additional_color_mgmt - enable additional color properties * @crtc: DRM CRTC @@ -552,7 +552,7 @@ static const struct drm_crtc_funcs amdgpu_dm_crtc_funcs = { #if defined(CONFIG_DEBUG_FS) .late_register = amdgpu_dm_crtc_late_register, #endif -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR .atomic_set_property = amdgpu_dm_atomic_crtc_set_property, .atomic_get_property = amdgpu_dm_atomic_crtc_get_property, #endif @@ -731,7 +731,7 @@ int amdgpu_dm_crtc_init(struct amdgpu_display_manager *dm, drm_mode_crtc_set_gamma_size(&acrtc->base, MAX_COLOR_LEGACY_LUT_ENTRIES); -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR dm_crtc_additional_color_mgmt(&acrtc->base); #endif return 0; diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c index 92472109f84a..d5f3e6d5aa7f 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c @@ -1598,7 +1598,7 @@ static void amdgpu_dm_plane_drm_plane_destroy_state(struct drm_plane *plane, drm_atomic_helper_plane_destroy_state(plane, state); } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR static void dm_atomic_plane_attach_color_mgmt_properties(struct amdgpu_display_manager *dm, struct drm_plane *plane) @@ -1789,7 +1789,7 @@ static const struct drm_plane_funcs dm_plane_funcs = { .atomic_duplicate_state = amdgpu_dm_plane_drm_plane_duplicate_state, .atomic_destroy_state = amdgpu_dm_plane_drm_plane_destroy_state, .format_mod_supported = amdgpu_dm_plane_format_mod_supported, -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR .atomic_set_property = dm_atomic_plane_set_property, .atomic_get_property = dm_atomic_plane_get_property, #endif @@ -1885,7 +1885,7 @@ int amdgpu_dm_plane_init(struct amdgpu_display_manager *dm, else drm_plane_helper_add(plane, &dm_plane_helper_funcs); -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR dm_atomic_plane_attach_color_mgmt_properties(dm, plane); #endif /* Create (reset) the plane state */ diff --git a/drivers/gpu/drm/amd/display/dc/bios/bios_parser2.c b/drivers/gpu/drm/amd/display/dc/bios/bios_parser2.c index a62f6c51301c..1d2c6019efac 100644 --- a/drivers/gpu/drm/amd/display/dc/bios/bios_parser2.c +++ b/drivers/gpu/drm/amd/display/dc/bios/bios_parser2.c @@ -1778,6 +1778,7 @@ static enum bp_result get_firmware_info_v3_1( struct dc_firmware_info *info) { struct atom_firmware_info_v3_1 *firmware_info; + struct atom_firmware_info_v3_2 *firmware_info32; struct atom_display_controller_info_v4_1 *dce_info = NULL; if (!info) @@ -1785,6 +1786,8 @@ static enum bp_result get_firmware_info_v3_1( firmware_info = GET_IMAGE(struct atom_firmware_info_v3_1, DATA_TABLES(firmwareinfo)); + firmware_info32 = GET_IMAGE(struct atom_firmware_info_v3_2, + DATA_TABLES(firmwareinfo)); dce_info = GET_IMAGE(struct atom_display_controller_info_v4_1, DATA_TABLES(dce_info)); @@ -1817,7 +1820,15 @@ static enum bp_result get_firmware_info_v3_1( bp->cmd_tbl.get_smu_clock_info(bp, SMU9_SYSPLL0_ID) * 10; } - info->oem_i2c_present = false; + /* These fields are marked as reserved in v3_1, but they appear to be populated + * properly. + */ + if (firmware_info32->board_i2c_feature_id == 0x2) { + info->oem_i2c_present = true; + info->oem_i2c_obj_id = firmware_info32->board_i2c_feature_gpio_id; + } else { + info->oem_i2c_present = false; + } return BP_RESULT_OK; } diff --git a/drivers/gpu/drm/amd/display/dc/core/dc_link_exports.c b/drivers/gpu/drm/amd/display/dc/core/dc_link_exports.c index c1b79b379447..261c3bc4d46e 100644 --- a/drivers/gpu/drm/amd/display/dc/core/dc_link_exports.c +++ b/drivers/gpu/drm/amd/display/dc/core/dc_link_exports.c @@ -150,6 +150,12 @@ bool dc_link_update_dsc_config(struct pipe_ctx *pipe_ctx) return link->dc->link_srv->update_dsc_config(pipe_ctx); } +struct ddc_service * +dc_get_oem_i2c_device(struct dc *dc) +{ + return dc->res_pool->oem_device; +} + bool dc_is_oem_i2c_device_present( struct dc *dc, size_t slave_address) diff --git a/drivers/gpu/drm/amd/display/dc/dc.h b/drivers/gpu/drm/amd/display/dc/dc.h index ab77dcbc1058..1f9c8bb6c79d 100644 --- a/drivers/gpu/drm/amd/display/dc/dc.h +++ b/drivers/gpu/drm/amd/display/dc/dc.h @@ -1949,6 +1949,9 @@ int dc_link_aux_transfer_raw(struct ddc_service *ddc, struct aux_payload *payload, enum aux_return_code_type *operation_result); +struct ddc_service * +dc_get_oem_i2c_device(struct dc *dc); + bool dc_is_oem_i2c_device_present( struct dc *dc, size_t slave_address diff --git a/drivers/gpu/drm/amd/display/dc/resource/dce120/dce120_resource.c b/drivers/gpu/drm/amd/display/dc/resource/dce120/dce120_resource.c index c63c59623433..eb1e158d3436 100644 --- a/drivers/gpu/drm/amd/display/dc/resource/dce120/dce120_resource.c +++ b/drivers/gpu/drm/amd/display/dc/resource/dce120/dce120_resource.c @@ -67,6 +67,7 @@ #include "reg_helper.h" #include "dce100/dce100_resource.h" +#include "link.h" #ifndef mmDP0_DP_DPHY_INTERNAL_CTRL #define mmDP0_DP_DPHY_INTERNAL_CTRL 0x210f @@ -659,6 +660,12 @@ static void dce120_resource_destruct(struct dce110_resource_pool *pool) if (pool->base.dmcu != NULL) dce_dmcu_destroy(&pool->base.dmcu); + + if (pool->base.oem_device != NULL) { + struct dc *dc = pool->base.oem_device->ctx->dc; + + dc->link_srv->destroy_ddc_service(&pool->base.oem_device); + } } static void read_dce_straps( @@ -1054,6 +1061,7 @@ static bool dce120_resource_construct( struct dc *dc, struct dce110_resource_pool *pool) { + struct ddc_service_init_data ddc_init_data = {0}; unsigned int i; int j; struct dc_context *ctx = dc->ctx; @@ -1257,6 +1265,15 @@ static bool dce120_resource_construct( bw_calcs_data_update_from_pplib(dc); + if (dc->ctx->dc_bios->fw_info.oem_i2c_present) { + ddc_init_data.ctx = dc->ctx; + ddc_init_data.link = NULL; + ddc_init_data.id.id = dc->ctx->dc_bios->fw_info.oem_i2c_obj_id; + ddc_init_data.id.enum_id = 0; + ddc_init_data.id.type = OBJECT_TYPE_GENERIC; + pool->base.oem_device = dc->link_srv->create_ddc_service(&ddc_init_data); + } + return true; irqs_create_fail: diff --git a/drivers/gpu/drm/amd/pm/amdgpu_pm.c b/drivers/gpu/drm/amd/pm/amdgpu_pm.c index 77b1f061bbf0..89ffc9fd9463 100644 --- a/drivers/gpu/drm/amd/pm/amdgpu_pm.c +++ b/drivers/gpu/drm/amd/pm/amdgpu_pm.c @@ -3180,6 +3180,9 @@ static ssize_t amdgpu_hwmon_show_power_cap_min(struct device *dev, struct device_attribute *attr, char *buf) { + if (amdgpu_ignore_min_pcap) + return sysfs_emit(buf, "%i\n", 0); + return amdgpu_hwmon_show_power_cap_generic(dev, attr, buf, PP_PWR_LIMIT_MIN); } diff --git a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c index ed9dac00ebfb..b3fd5e9df44c 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c +++ b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c @@ -2823,7 +2823,10 @@ int smu_get_power_limit(void *handle, *limit = smu->max_power_limit; break; case SMU_PPT_LIMIT_MIN: - *limit = smu->min_power_limit; + if (amdgpu_ignore_min_pcap) + *limit = 0; + else + *limit = smu->min_power_limit; break; default: return -EINVAL; @@ -2847,7 +2850,14 @@ static int smu_set_power_limit(void *handle, uint32_t limit) if (smu->ppt_funcs->set_power_limit) return smu->ppt_funcs->set_power_limit(smu, limit_type, limit); - if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { + if (amdgpu_ignore_min_pcap) { + if ((limit > smu->max_power_limit)) { + dev_err(smu->adev->dev, + "New power limit (%d) is over the max allowed %d\n", + limit, smu->max_power_limit); + return -EINVAL; + } + } else if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { dev_err(smu->adev->dev, "New power limit (%d) is out of range [%d,%d]\n", limit, smu->min_power_limit, smu->max_power_limit); diff --git a/drivers/input/evdev.c b/drivers/input/evdev.c index b5cbb57ee5f6..a0f7fa1518c6 100644 --- a/drivers/input/evdev.c +++ b/drivers/input/evdev.c @@ -46,6 +46,7 @@ struct evdev_client { struct fasync_struct *fasync; struct evdev *evdev; struct list_head node; + struct rcu_head rcu; enum input_clock_type clk_type; bool revoked; unsigned long *evmasks[EV_CNT]; @@ -368,13 +369,22 @@ static void evdev_attach_client(struct evdev *evdev, spin_unlock(&evdev->client_lock); } +static void evdev_reclaim_client(struct rcu_head *rp) +{ + struct evdev_client *client = container_of(rp, struct evdev_client, rcu); + unsigned int i; + for (i = 0; i < EV_CNT; ++i) + bitmap_free(client->evmasks[i]); + kvfree(client); +} + static void evdev_detach_client(struct evdev *evdev, struct evdev_client *client) { spin_lock(&evdev->client_lock); list_del_rcu(&client->node); spin_unlock(&evdev->client_lock); - synchronize_rcu(); + call_rcu(&client->rcu, evdev_reclaim_client); } static int evdev_open_device(struct evdev *evdev) @@ -427,7 +437,6 @@ static int evdev_release(struct inode *inode, struct file *file) { struct evdev_client *client = file->private_data; struct evdev *evdev = client->evdev; - unsigned int i; mutex_lock(&evdev->mutex); @@ -439,11 +448,6 @@ static int evdev_release(struct inode *inode, struct file *file) evdev_detach_client(evdev, client); - for (i = 0; i < EV_CNT; ++i) - bitmap_free(client->evmasks[i]); - - kvfree(client); - evdev_close_device(evdev); return 0; @@ -486,7 +490,6 @@ static int evdev_open(struct inode *inode, struct file *file) err_free_client: evdev_detach_client(evdev, client); - kvfree(client); return error; } diff --git a/drivers/md/dm-crypt.c b/drivers/md/dm-crypt.c index 02a2919f4e5a..67fde4125238 100644 --- a/drivers/md/dm-crypt.c +++ b/drivers/md/dm-crypt.c @@ -3305,6 +3305,11 @@ static int crypt_ctr(struct dm_target *ti, unsigned int argc, char **argv) goto bad; } +#ifdef CONFIG_CACHY + set_bit(DM_CRYPT_NO_READ_WORKQUEUE, &cc->flags); + set_bit(DM_CRYPT_NO_WRITE_WORKQUEUE, &cc->flags); +#endif + ret = crypt_ctr_cipher(ti, argv[0], argv[1]); if (ret < 0) goto bad; diff --git a/drivers/media/v4l2-core/Kconfig b/drivers/media/v4l2-core/Kconfig index 331b8e535e5b..80dabeebf580 100644 --- a/drivers/media/v4l2-core/Kconfig +++ b/drivers/media/v4l2-core/Kconfig @@ -40,6 +40,11 @@ config VIDEO_TUNER config V4L2_JPEG_HELPER tristate +config V4L2_LOOPBACK + tristate "V4L2 loopback device" + help + V4L2 loopback device + # Used by drivers that need v4l2-h264.ko config V4L2_H264 tristate diff --git a/drivers/media/v4l2-core/Makefile b/drivers/media/v4l2-core/Makefile index 2177b9d63a8f..c179507cedc4 100644 --- a/drivers/media/v4l2-core/Makefile +++ b/drivers/media/v4l2-core/Makefile @@ -33,5 +33,7 @@ obj-$(CONFIG_V4L2_JPEG_HELPER) += v4l2-jpeg.o obj-$(CONFIG_V4L2_MEM2MEM_DEV) += v4l2-mem2mem.o obj-$(CONFIG_V4L2_VP9) += v4l2-vp9.o +obj-$(CONFIG_V4L2_LOOPBACK) += v4l2loopback.o + obj-$(CONFIG_VIDEO_TUNER) += tuner.o obj-$(CONFIG_VIDEO_DEV) += v4l2-dv-timings.o videodev.o diff --git a/drivers/media/v4l2-core/v4l2loopback.c b/drivers/media/v4l2-core/v4l2loopback.c new file mode 100644 index 000000000000..468d47f4950b --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback.c @@ -0,0 +1,3292 @@ +/* -*- c-file-style: "linux" -*- */ +/* + * v4l2loopback.c -- video4linux2 loopback driver + * + * Copyright (C) 2005-2009 Vasily Levin (vasaka@gmail.com) + * Copyright (C) 2010-2023 IOhannes m zmoelnig (zmoelnig@iem.at) + * Copyright (C) 2011 Stefan Diewald (stefan.diewald@mytum.de) + * Copyright (C) 2012 Anton Novikov (random.plant@gmail.com) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "v4l2loopback.h" + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 0, 0) +#error This module is not supported on kernels before 4.0.0. +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0) +#define strscpy strlcpy +#endif + +#if defined(timer_setup) && defined(from_timer) +#define HAVE_TIMER_SETUP +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 7, 0) +#define VFL_TYPE_VIDEO VFL_TYPE_GRABBER +#endif + +#define V4L2LOOPBACK_VERSION_CODE \ + KERNEL_VERSION(V4L2LOOPBACK_VERSION_MAJOR, V4L2LOOPBACK_VERSION_MINOR, \ + V4L2LOOPBACK_VERSION_BUGFIX) + +MODULE_DESCRIPTION("V4L2 loopback video device"); +MODULE_AUTHOR("Vasily Levin, " + "IOhannes m zmoelnig ," + "Stefan Diewald," + "Anton Novikov" + "et al."); +#ifdef SNAPSHOT_VERSION +MODULE_VERSION(__stringify(SNAPSHOT_VERSION)); +#else +MODULE_VERSION("" __stringify(V4L2LOOPBACK_VERSION_MAJOR) "." __stringify( + V4L2LOOPBACK_VERSION_MINOR) "." __stringify(V4L2LOOPBACK_VERSION_BUGFIX)); +#endif +MODULE_LICENSE("GPL"); + +/* + * helpers + */ +#define dprintk(fmt, args...) \ + do { \ + if (debug > 0) { \ + printk(KERN_INFO "v4l2-loopback[" __stringify( \ + __LINE__) "], pid(%d): " fmt, \ + task_pid_nr(current), ##args); \ + } \ + } while (0) + +#define MARK() \ + do { \ + if (debug > 1) { \ + printk(KERN_INFO "%s:%d[%s], pid(%d)\n", __FILE__, \ + __LINE__, __func__, task_pid_nr(current)); \ + } \ + } while (0) + +#define dprintkrw(fmt, args...) \ + do { \ + if (debug > 2) { \ + printk(KERN_INFO "v4l2-loopback[" __stringify( \ + __LINE__) "], pid(%d): " fmt, \ + task_pid_nr(current), ##args); \ + } \ + } while (0) + +static inline void v4l2l_get_timestamp(struct v4l2_buffer *b) +{ + struct timespec64 ts; + ktime_get_ts64(&ts); + + b->timestamp.tv_sec = ts.tv_sec; + b->timestamp.tv_usec = (ts.tv_nsec / NSEC_PER_USEC); + b->flags |= V4L2_BUF_FLAG_TIMESTAMP_MONOTONIC; + b->flags &= ~V4L2_BUF_FLAG_TIMESTAMP_COPY; +} + +#if BITS_PER_LONG == 32 +#include /* do_div() for 64bit division */ +static inline int v4l2l_mod64(const s64 A, const u32 B) +{ + u64 a = (u64)A; + u32 b = B; + + if (A > 0) + return do_div(a, b); + a = -A; + return -do_div(a, b); +} +#else +static inline int v4l2l_mod64(const s64 A, const u32 B) +{ + return A % B; +} +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 16, 0) +typedef unsigned __poll_t; +#endif + +/* module constants + * can be overridden during he build process using something like + * make KCPPFLAGS="-DMAX_DEVICES=100" + */ + +/* maximum number of v4l2loopback devices that can be created */ +#ifndef MAX_DEVICES +#define MAX_DEVICES 8 +#endif + +/* whether the default is to announce capabilities exclusively or not */ +#ifndef V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS +#define V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS 0 +#endif + +/* when a producer is considered to have gone stale */ +#ifndef MAX_TIMEOUT +#define MAX_TIMEOUT (100 * 1000) /* in msecs */ +#endif + +/* max buffers that can be mapped, actually they + * are all mapped to max_buffers buffers */ +#ifndef MAX_BUFFERS +#define MAX_BUFFERS 32 +#endif + +/* module parameters */ +static int debug = 0; +module_param(debug, int, S_IRUGO | S_IWUSR); +MODULE_PARM_DESC(debug, "debugging level (higher values == more verbose)"); + +#define V4L2LOOPBACK_DEFAULT_MAX_BUFFERS 2 +static int max_buffers = V4L2LOOPBACK_DEFAULT_MAX_BUFFERS; +module_param(max_buffers, int, S_IRUGO); +MODULE_PARM_DESC(max_buffers, + "how many buffers should be allocated [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_MAX_BUFFERS) "]"); + +/* how many times a device can be opened + * the per-module default value can be overridden on a per-device basis using + * the /sys/devices interface + * + * note that max_openers should be at least 2 in order to get a working system: + * one opener for the producer and one opener for the consumer + * however, we leave that to the user + */ +#define V4L2LOOPBACK_DEFAULT_MAX_OPENERS 10 +static int max_openers = V4L2LOOPBACK_DEFAULT_MAX_OPENERS; +module_param(max_openers, int, S_IRUGO | S_IWUSR); +MODULE_PARM_DESC( + max_openers, + "how many users can open the loopback device [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_MAX_OPENERS) "]"); + +static int devices = -1; +module_param(devices, int, 0); +MODULE_PARM_DESC(devices, "how many devices should be created"); + +static int video_nr[MAX_DEVICES] = { [0 ...(MAX_DEVICES - 1)] = -1 }; +module_param_array(video_nr, int, NULL, 0444); +MODULE_PARM_DESC(video_nr, + "video device numbers (-1=auto, 0=/dev/video0, etc.)"); + +static char *card_label[MAX_DEVICES]; +module_param_array(card_label, charp, NULL, 0000); +MODULE_PARM_DESC(card_label, "card labels for each device"); + +static bool exclusive_caps[MAX_DEVICES] = { + [0 ...(MAX_DEVICES - 1)] = V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS +}; +module_param_array(exclusive_caps, bool, NULL, 0444); +/* FIXXME: wording */ +MODULE_PARM_DESC( + exclusive_caps, + "whether to announce OUTPUT/CAPTURE capabilities exclusively or not [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS) "]"); + +/* format specifications */ +#define V4L2LOOPBACK_SIZE_MIN_WIDTH 2 +#define V4L2LOOPBACK_SIZE_MIN_HEIGHT 1 +#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH 8192 +#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT 8192 + +#define V4L2LOOPBACK_SIZE_DEFAULT_WIDTH 640 +#define V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT 480 + +static int max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; +module_param(max_width, int, S_IRUGO); +MODULE_PARM_DESC(max_width, + "maximum allowed frame width [DEFAULT: " __stringify( + V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH) "]"); +static int max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; +module_param(max_height, int, S_IRUGO); +MODULE_PARM_DESC(max_height, + "maximum allowed frame height [DEFAULT: " __stringify( + V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT) "]"); + +static DEFINE_IDR(v4l2loopback_index_idr); +static DEFINE_MUTEX(v4l2loopback_ctl_mutex); + +/* frame intervals */ +#define V4L2LOOPBACK_FRAME_INTERVAL_MAX __UINT32_MAX__ +#define V4L2LOOPBACK_FPS_DEFAULT 30 +#define V4L2LOOPBACK_FPS_MAX 1000 + +/* control IDs */ +#define V4L2LOOPBACK_CID_BASE (V4L2_CID_USER_BASE | 0xf000) +#define CID_KEEP_FORMAT (V4L2LOOPBACK_CID_BASE + 0) +#define CID_SUSTAIN_FRAMERATE (V4L2LOOPBACK_CID_BASE + 1) +#define CID_TIMEOUT (V4L2LOOPBACK_CID_BASE + 2) +#define CID_TIMEOUT_IMAGE_IO (V4L2LOOPBACK_CID_BASE + 3) + +static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl); +static const struct v4l2_ctrl_ops v4l2loopback_ctrl_ops = { + .s_ctrl = v4l2loopback_s_ctrl, +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_keepformat = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_KEEP_FORMAT, + .name = "keep_format", + .type = V4L2_CTRL_TYPE_BOOLEAN, + .min = 0, + .max = 1, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_sustainframerate = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_SUSTAIN_FRAMERATE, + .name = "sustain_framerate", + .type = V4L2_CTRL_TYPE_BOOLEAN, + .min = 0, + .max = 1, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeout = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_TIMEOUT, + .name = "timeout", + .type = V4L2_CTRL_TYPE_INTEGER, + .min = 0, + .max = MAX_TIMEOUT, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeoutimageio = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_TIMEOUT_IMAGE_IO, + .name = "timeout_image_io", + .type = V4L2_CTRL_TYPE_BUTTON, + .min = 0, + .max = 0, + .step = 0, + .def = 0, + // clang-format on +}; + +/* module structures */ +struct v4l2loopback_private { + int device_nr; +}; + +/* TODO(vasaka) use typenames which are common to kernel, but first find out if + * it is needed */ +/* struct keeping state and settings of loopback device */ + +struct v4l2l_buffer { + struct v4l2_buffer buffer; + struct list_head list_head; + atomic_t use_count; +}; + +struct v4l2_loopback_device { + struct v4l2_device v4l2_dev; + struct v4l2_ctrl_handler ctrl_handler; + struct video_device *vdev; + + /* loopback device-specific parameters */ + char card_label[32]; + bool announce_all_caps; /* announce both OUTPUT and CAPTURE capabilities + * when true; else announce OUTPUT when no + * writer is streaming, otherwise CAPTURE. */ + int max_openers; /* how many times can this device be opened */ + int min_width, max_width; + int min_height, max_height; + + /* pixel and stream format */ + struct v4l2_pix_format pix_format; + bool pix_format_has_valid_sizeimage; + struct v4l2_captureparm capture_param; + unsigned long frame_jiffies; + + /* ctrls */ + int keep_format; /* CID_KEEP_FORMAT; lock the format, do not free + * on close(), and when `!announce_all_caps` do NOT + * fall back to OUTPUT when no writers attached (clear + * `keep_format` to attach a new writer) */ + int sustain_framerate; /* CID_SUSTAIN_FRAMERATE; duplicate frames to maintain + (close to) nominal framerate */ + unsigned long timeout_jiffies; /* CID_TIMEOUT; 0 means disabled */ + int timeout_image_io; /* CID_TIMEOUT_IMAGE_IO; next opener will + * queue/dequeue the timeout image buffer */ + + /* buffers for OUTPUT and CAPTURE */ + u8 *image; /* pointer to actual buffers data */ + unsigned long image_size; /* number of bytes alloc'd for all buffers */ + struct v4l2l_buffer buffers[MAX_BUFFERS]; /* inner driver buffers */ + u32 buffer_count; /* should not be big, 4 is a good choice */ + u32 buffer_size; /* number of bytes alloc'd per buffer */ + u32 used_buffer_count; /* number of buffers allocated to openers */ + struct list_head outbufs_list; /* FIFO queue for OUTPUT buffers */ + u32 bufpos2index[MAX_BUFFERS]; /* mapping of `(position % used_buffers)` + * to `buffers[index]` */ + s64 write_position; /* sequence number of last 'displayed' buffer plus + * one */ + + /* synchronization between openers */ + atomic_t open_count; + struct mutex image_mutex; /* mutex for allocating image(s) and + * exchanging format tokens */ + spinlock_t lock; /* lock for the timeout and framerate timers */ + spinlock_t list_lock; /* lock for the OUTPUT buffer queue */ + wait_queue_head_t read_event; + u32 format_tokens; /* tokens to 'set format' for OUTPUT, CAPTURE, or + * timeout buffers */ + u32 stream_tokens; /* tokens to 'start' OUTPUT, CAPTURE, or timeout + * stream */ + + /* sustain framerate */ + struct timer_list sustain_timer; + unsigned int reread_count; + + /* timeout */ + u8 *timeout_image; /* copied to outgoing buffers when timeout passes */ + struct v4l2l_buffer timeout_buffer; + u32 timeout_buffer_size; /* number bytes alloc'd for timeout buffer */ + struct timer_list timeout_timer; + int timeout_happened; +}; + +enum v4l2l_io_method { + V4L2L_IO_NONE = 0, + V4L2L_IO_MMAP = 1, + V4L2L_IO_FILE = 2, + V4L2L_IO_TIMEOUT = 3, +}; + +/* struct keeping state and type of opener */ +struct v4l2_loopback_opener { + u32 format_token; /* token (if any) for type used in call to S_FMT or + * REQBUFS */ + u32 stream_token; /* token (if any) for type used in call to STREAMON */ + u32 buffer_count; /* number of buffers (if any) that opener acquired via + * REQBUFS */ + s64 read_position; /* sequence number of the next 'captured' frame */ + unsigned int reread_count; + enum v4l2l_io_method io_method; + + struct v4l2_fh fh; +}; + +#define fh_to_opener(ptr) container_of((ptr), struct v4l2_loopback_opener, fh) + +/* this is heavily inspired by the bttv driver found in the linux kernel */ +struct v4l2l_format { + char *name; + int fourcc; /* video4linux 2 */ + int depth; /* bit/pixel */ + int flags; +}; +/* set the v4l2l_format.flags to PLANAR for non-packed formats */ +#define FORMAT_FLAGS_PLANAR 0x01 +#define FORMAT_FLAGS_COMPRESSED 0x02 + +#include "v4l2loopback_formats.h" + +#ifndef V4L2_TYPE_IS_CAPTURE +#define V4L2_TYPE_IS_CAPTURE(type) \ + ((type) == V4L2_BUF_TYPE_VIDEO_CAPTURE || \ + (type) == V4L2_BUF_TYPE_VIDEO_CAPTURE_MPLANE) +#endif /* V4L2_TYPE_IS_CAPTURE */ +#ifndef V4L2_TYPE_IS_OUTPUT +#define V4L2_TYPE_IS_OUTPUT(type) \ + ((type) == V4L2_BUF_TYPE_VIDEO_OUTPUT || \ + (type) == V4L2_BUF_TYPE_VIDEO_OUTPUT_MPLANE) +#endif /* V4L2_TYPE_IS_OUTPUT */ + +/* token values for privilege to set format or start/stop stream */ +#define V4L2L_TOKEN_CAPTURE 0x01 +#define V4L2L_TOKEN_OUTPUT 0x02 +#define V4L2L_TOKEN_TIMEOUT 0x04 +#define V4L2L_TOKEN_MASK \ + (V4L2L_TOKEN_CAPTURE | V4L2L_TOKEN_OUTPUT | V4L2L_TOKEN_TIMEOUT) + +/* helpers for token exchange and token status */ +#define token_from_type(type) \ + (V4L2_TYPE_IS_CAPTURE(type) ? V4L2L_TOKEN_CAPTURE : V4L2L_TOKEN_OUTPUT) +#define acquire_token(dev, opener, label, token) \ + do { \ + (opener)->label##_token = token; \ + (dev)->label##_tokens &= ~token; \ + } while (0) +#define release_token(dev, opener, label) \ + do { \ + (dev)->label##_tokens |= (opener)->label##_token; \ + (opener)->label##_token = 0; \ + } while (0) +#define has_output_token(token) (token & V4L2L_TOKEN_OUTPUT) +#define has_capture_token(token) (token & V4L2L_TOKEN_CAPTURE) +#define has_no_owners(dev) ((~((dev)->format_tokens) & V4L2L_TOKEN_MASK) == 0) +#define has_other_owners(opener, dev) \ + (~((dev)->format_tokens ^ (opener)->format_token) & V4L2L_TOKEN_MASK) +#define need_timeout_buffer(dev, token) \ + ((dev)->timeout_jiffies > 0 || (token) & V4L2L_TOKEN_TIMEOUT) + +static const unsigned int FORMATS = ARRAY_SIZE(formats); + +static char *fourcc2str(unsigned int fourcc, char buf[5]) +{ + buf[0] = (fourcc >> 0) & 0xFF; + buf[1] = (fourcc >> 8) & 0xFF; + buf[2] = (fourcc >> 16) & 0xFF; + buf[3] = (fourcc >> 24) & 0xFF; + buf[4] = 0; + + return buf; +} + +static const struct v4l2l_format *format_by_fourcc(int fourcc) +{ + unsigned int i; + char buf[5]; + + for (i = 0; i < FORMATS; i++) { + if (formats[i].fourcc == fourcc) + return formats + i; + } + + dprintk("unsupported format '%4s'\n", fourcc2str(fourcc, buf)); + return NULL; +} + +static void pix_format_set_size(struct v4l2_pix_format *f, + const struct v4l2l_format *fmt, + unsigned int width, unsigned int height) +{ + f->width = width; + f->height = height; + + if (fmt->flags & FORMAT_FLAGS_PLANAR) { + f->bytesperline = width; /* Y plane */ + f->sizeimage = (width * height * fmt->depth) >> 3; + } else if (fmt->flags & FORMAT_FLAGS_COMPRESSED) { + /* doesn't make sense for compressed formats */ + f->bytesperline = 0; + f->sizeimage = (width * height * fmt->depth) >> 3; + } else { + f->bytesperline = (width * fmt->depth) >> 3; + f->sizeimage = height * f->bytesperline; + } +} + +static int v4l2l_fill_format(struct v4l2_format *fmt, const u32 minwidth, + const u32 maxwidth, const u32 minheight, + const u32 maxheight) +{ + u32 width = fmt->fmt.pix.width, height = fmt->fmt.pix.height; + u32 pixelformat = fmt->fmt.pix.pixelformat; + struct v4l2_format fmt0 = *fmt; + u32 bytesperline = 0, sizeimage = 0; + + if (!width) + width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; + if (!height) + height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; + width = clamp_val(width, minwidth, maxwidth); + height = clamp_val(height, minheight, maxheight); + + /* sets: width,height,pixelformat,bytesperline,sizeimage */ + if (!(V4L2_TYPE_IS_MULTIPLANAR(fmt0.type))) { + fmt0.fmt.pix.bytesperline = 0; + fmt0.fmt.pix.sizeimage = 0; + } + + if (0) { + ; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) + } else if (!v4l2_fill_pixfmt(&fmt0.fmt.pix, pixelformat, width, + height)) { + ; + } else if (!v4l2_fill_pixfmt_mp(&fmt0.fmt.pix_mp, pixelformat, width, + height)) { + ; +#endif + } else { + const struct v4l2l_format *format = + format_by_fourcc(pixelformat); + if (!format) + return -EINVAL; + pix_format_set_size(&fmt0.fmt.pix, format, width, height); + fmt0.fmt.pix.pixelformat = format->fourcc; + } + + if (V4L2_TYPE_IS_MULTIPLANAR(fmt0.type)) { + *fmt = fmt0; + + if ((fmt->fmt.pix_mp.colorspace == V4L2_COLORSPACE_DEFAULT) || + (fmt->fmt.pix_mp.colorspace > V4L2_COLORSPACE_DCI_P3)) + fmt->fmt.pix_mp.colorspace = V4L2_COLORSPACE_SRGB; + if (V4L2_FIELD_ANY == fmt->fmt.pix_mp.field) + fmt->fmt.pix_mp.field = V4L2_FIELD_NONE; + } else { + bytesperline = fmt->fmt.pix.bytesperline; + sizeimage = fmt->fmt.pix.sizeimage; + + *fmt = fmt0; + + if (!fmt->fmt.pix.bytesperline) + fmt->fmt.pix.bytesperline = bytesperline; + if (!fmt->fmt.pix.sizeimage) + fmt->fmt.pix.sizeimage = sizeimage; + + if ((fmt->fmt.pix.colorspace == V4L2_COLORSPACE_DEFAULT) || + (fmt->fmt.pix.colorspace > V4L2_COLORSPACE_DCI_P3)) + fmt->fmt.pix.colorspace = V4L2_COLORSPACE_SRGB; + if (V4L2_FIELD_ANY == fmt->fmt.pix.field) + fmt->fmt.pix.field = V4L2_FIELD_NONE; + } + + return 0; +} + +/* Checks if v4l2l_fill_format() has set a valid, fixed sizeimage val. */ +static bool v4l2l_pix_format_has_valid_sizeimage(struct v4l2_format *fmt) +{ +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) + const struct v4l2_format_info *info; + + info = v4l2_format_info(fmt->fmt.pix.pixelformat); + if (info && info->mem_planes == 1) + return true; +#endif + + return false; +} + +static int pix_format_eq(const struct v4l2_pix_format *ref, + const struct v4l2_pix_format *tgt, int strict) +{ + /* check if the two formats are equivalent. + * ANY fields are handled gracefully + */ +#define _pix_format_eq0(x) \ + if (ref->x != tgt->x) \ + result = 0 +#define _pix_format_eq1(x, def) \ + do { \ + if ((def != tgt->x) && (ref->x != tgt->x)) { \ + printk(KERN_INFO #x " failed"); \ + result = 0; \ + } \ + } while (0) + int result = 1; + _pix_format_eq0(width); + _pix_format_eq0(height); + _pix_format_eq0(pixelformat); + if (!strict) + return result; + _pix_format_eq1(field, V4L2_FIELD_ANY); + _pix_format_eq0(bytesperline); + _pix_format_eq0(sizeimage); + _pix_format_eq1(colorspace, V4L2_COLORSPACE_DEFAULT); + return result; +} + +static void set_timeperframe(struct v4l2_loopback_device *dev, + struct v4l2_fract *tpf) +{ + if (!tpf->denominator && !tpf->numerator) { + tpf->numerator = 1; + tpf->denominator = V4L2LOOPBACK_FPS_DEFAULT; + } else if (tpf->numerator > + V4L2LOOPBACK_FRAME_INTERVAL_MAX * tpf->denominator) { + /* divide-by-zero or greater than maximum interval => min FPS */ + tpf->numerator = V4L2LOOPBACK_FRAME_INTERVAL_MAX; + tpf->denominator = 1; + } else if (tpf->numerator * V4L2LOOPBACK_FPS_MAX < tpf->denominator) { + /* zero or lower than minimum interval => max FPS */ + tpf->numerator = 1; + tpf->denominator = V4L2LOOPBACK_FPS_MAX; + } + + dev->capture_param.timeperframe = *tpf; + dev->frame_jiffies = + max(1UL, (msecs_to_jiffies(1000) * tpf->numerator) / + tpf->denominator); +} + +static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd); + +/* device attributes */ +/* available via sysfs: /sys/devices/virtual/video4linux/video* */ + +static ssize_t attr_show_format(struct device *cd, + struct device_attribute *attr, char *buf) +{ + /* gets the current format as "FOURCC:WxH@f/s", e.g. "YUYV:320x240@1000/30" */ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + const struct v4l2_fract *tpf; + char buf4cc[5], buf_fps[32]; + + if (!dev || (has_no_owners(dev) && !dev->keep_format)) + return 0; + tpf = &dev->capture_param.timeperframe; + + fourcc2str(dev->pix_format.pixelformat, buf4cc); + if (tpf->numerator == 1) + snprintf(buf_fps, sizeof(buf_fps), "%u", tpf->denominator); + else + snprintf(buf_fps, sizeof(buf_fps), "%u/%u", tpf->denominator, + tpf->numerator); + return sprintf(buf, "%4s:%ux%u@%s\n", buf4cc, dev->pix_format.width, + dev->pix_format.height, buf_fps); +} + +static ssize_t attr_store_format(struct device *cd, + struct device_attribute *attr, const char *buf, + size_t len) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + int fps_num = 0, fps_den = 1; + + if (!dev) + return -ENODEV; + + /* only fps changing is supported */ + if (sscanf(buf, "@%u/%u", &fps_num, &fps_den) > 0) { + struct v4l2_fract f = { .numerator = fps_den, + .denominator = fps_num }; + set_timeperframe(dev, &f); + return len; + } + return -EINVAL; +} + +static DEVICE_ATTR(format, S_IRUGO | S_IWUSR, attr_show_format, + attr_store_format); + +static ssize_t attr_show_buffers(struct device *cd, + struct device_attribute *attr, char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + return sprintf(buf, "%u\n", dev->used_buffer_count); +} + +static DEVICE_ATTR(buffers, S_IRUGO, attr_show_buffers, NULL); + +static ssize_t attr_show_maxopeners(struct device *cd, + struct device_attribute *attr, char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + return sprintf(buf, "%d\n", dev->max_openers); +} + +static ssize_t attr_store_maxopeners(struct device *cd, + struct device_attribute *attr, + const char *buf, size_t len) +{ + struct v4l2_loopback_device *dev = NULL; + unsigned long curr = 0; + + if (kstrtoul(buf, 0, &curr)) + return -EINVAL; + + dev = v4l2loopback_cd2dev(cd); + if (!dev) + return -ENODEV; + + if (dev->max_openers == curr) + return len; + + if (curr > __INT_MAX__ || dev->open_count.counter > curr) { + /* request to limit to less openers as are currently attached to us */ + return -EINVAL; + } + + dev->max_openers = (int)curr; + + return len; +} + +static DEVICE_ATTR(max_openers, S_IRUGO | S_IWUSR, attr_show_maxopeners, + attr_store_maxopeners); + +static ssize_t attr_show_state(struct device *cd, struct device_attribute *attr, + char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + if (!has_output_token(dev->stream_tokens) || dev->keep_format) { + return sprintf(buf, "capture\n"); + } else + return sprintf(buf, "output\n"); + + return -EAGAIN; +} + +static DEVICE_ATTR(state, S_IRUGO, attr_show_state, NULL); + +static void v4l2loopback_remove_sysfs(struct video_device *vdev) +{ +#define V4L2_SYSFS_DESTROY(x) device_remove_file(&vdev->dev, &dev_attr_##x) + + if (vdev) { + V4L2_SYSFS_DESTROY(format); + V4L2_SYSFS_DESTROY(buffers); + V4L2_SYSFS_DESTROY(max_openers); + V4L2_SYSFS_DESTROY(state); + /* ... */ + } +} + +static void v4l2loopback_create_sysfs(struct video_device *vdev) +{ + int res = 0; + +#define V4L2_SYSFS_CREATE(x) \ + res = device_create_file(&vdev->dev, &dev_attr_##x); \ + if (res < 0) \ + break + if (!vdev) + return; + do { + V4L2_SYSFS_CREATE(format); + V4L2_SYSFS_CREATE(buffers); + V4L2_SYSFS_CREATE(max_openers); + V4L2_SYSFS_CREATE(state); + /* ... */ + } while (0); + + if (res >= 0) + return; + dev_err(&vdev->dev, "%s error: %d\n", __func__, res); +} + +/* Event APIs */ + +#define V4L2LOOPBACK_EVENT_BASE (V4L2_EVENT_PRIVATE_START) +#define V4L2LOOPBACK_EVENT_OFFSET 0x08E00000 +#define V4L2_EVENT_PRI_CLIENT_USAGE \ + (V4L2LOOPBACK_EVENT_BASE + V4L2LOOPBACK_EVENT_OFFSET + 1) + +struct v4l2_event_client_usage { + __u32 count; +}; + +/* global module data */ +/* find a device based on it's device-number (e.g. '3' for /dev/video3) */ +struct v4l2loopback_lookup_cb_data { + int device_nr; + struct v4l2_loopback_device *device; +}; +static int v4l2loopback_lookup_cb(int id, void *ptr, void *data) +{ + struct v4l2_loopback_device *device = ptr; + struct v4l2loopback_lookup_cb_data *cbdata = data; + if (cbdata && device && device->vdev) { + if (device->vdev->num == cbdata->device_nr) { + cbdata->device = device; + cbdata->device_nr = id; + return 1; + } + } + return 0; +} +static int v4l2loopback_lookup(int device_nr, + struct v4l2_loopback_device **device) +{ + struct v4l2loopback_lookup_cb_data data = { + .device_nr = device_nr, + .device = NULL, + }; + int err = idr_for_each(&v4l2loopback_index_idr, &v4l2loopback_lookup_cb, + &data); + if (1 == err) { + if (device) + *device = data.device; + return data.device_nr; + } + return -ENODEV; +} +#define v4l2loopback_get_vdev_nr(vdev) \ + ((struct v4l2loopback_private *)video_get_drvdata(vdev))->device_nr +static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd) +{ + struct video_device *loopdev = to_video_device(cd); + int device_nr = v4l2loopback_get_vdev_nr(loopdev); + + return idr_find(&v4l2loopback_index_idr, device_nr); +} + +static struct v4l2_loopback_device *v4l2loopback_getdevice(struct file *f) +{ + struct v4l2loopback_private *ptr = video_drvdata(f); + int nr = ptr->device_nr; + + return idr_find(&v4l2loopback_index_idr, nr); +} + +/* forward declarations */ +static void client_usage_queue_event(struct video_device *vdev); +static bool any_buffers_mapped(struct v4l2_loopback_device *dev); +static int allocate_buffers(struct v4l2_loopback_device *dev, + struct v4l2_pix_format *pix_format); +static void init_buffers(struct v4l2_loopback_device *dev, u32 bytes_used, + u32 buffer_size); +static void free_buffers(struct v4l2_loopback_device *dev); +static int allocate_timeout_buffer(struct v4l2_loopback_device *dev); +static void free_timeout_buffer(struct v4l2_loopback_device *dev); +static void check_timers(struct v4l2_loopback_device *dev); +static const struct v4l2_file_operations v4l2_loopback_fops; +static const struct v4l2_ioctl_ops v4l2_loopback_ioctl_ops; + +/* V4L2 ioctl caps and params calls */ +/* returns device capabilities + * called on VIDIOC_QUERYCAP + */ +static int vidioc_querycap(struct file *file, void *fh, + struct v4l2_capability *cap) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + int device_nr = v4l2loopback_get_vdev_nr(dev->vdev); + __u32 capabilities = V4L2_CAP_STREAMING | V4L2_CAP_READWRITE; + + strscpy(cap->driver, "v4l2 loopback", sizeof(cap->driver)); + snprintf(cap->card, sizeof(cap->card), "%s", dev->card_label); + snprintf(cap->bus_info, sizeof(cap->bus_info), + "platform:v4l2loopback-%03d", device_nr); + + if (dev->announce_all_caps) { + capabilities |= V4L2_CAP_VIDEO_CAPTURE | V4L2_CAP_VIDEO_OUTPUT; + } else { + if (opener->io_method == V4L2L_IO_TIMEOUT || + (has_output_token(dev->stream_tokens) && + !dev->keep_format)) { + capabilities |= V4L2_CAP_VIDEO_OUTPUT; + } else + capabilities |= V4L2_CAP_VIDEO_CAPTURE; + } + +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) + dev->vdev->device_caps = +#endif /* >=linux-4.7.0 */ + cap->device_caps = cap->capabilities = capabilities; + + cap->capabilities |= V4L2_CAP_DEVICE_CAPS; + + memset(cap->reserved, 0, sizeof(cap->reserved)); + return 0; +} + +static int vidioc_enum_framesizes(struct file *file, void *fh, + struct v4l2_frmsizeenum *argp) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + + /* there can be only one... */ + if (argp->index) + return -EINVAL; + + if (dev->keep_format || has_other_owners(opener, dev)) { + /* only current frame size supported */ + if (argp->pixel_format != dev->pix_format.pixelformat) + return -EINVAL; + + argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; + + argp->discrete.width = dev->pix_format.width; + argp->discrete.height = dev->pix_format.height; + } else { + /* return continuous sizes if pixel format is supported */ + if (NULL == format_by_fourcc(argp->pixel_format)) + return -EINVAL; + + if (dev->min_width == dev->max_width && + dev->min_height == dev->max_height) { + argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; + + argp->discrete.width = dev->min_width; + argp->discrete.height = dev->min_height; + } else { + argp->type = V4L2_FRMSIZE_TYPE_CONTINUOUS; + + argp->stepwise.min_width = dev->min_width; + argp->stepwise.min_height = dev->min_height; + + argp->stepwise.max_width = dev->max_width; + argp->stepwise.max_height = dev->max_height; + + argp->stepwise.step_width = 1; + argp->stepwise.step_height = 1; + } + } + return 0; +} + +/* Test if the device is currently 'capable' of the buffer (stream) type when + * the `exclusive_caps` parameter is set. `keep_format` should lock the format + * and prevent free of buffers */ +static int check_buffer_capability(struct v4l2_loopback_device *dev, + struct v4l2_loopback_opener *opener, + enum v4l2_buf_type type) +{ + /* short-circuit for (non-compliant) timeout image mode */ + if (opener->io_method == V4L2L_IO_TIMEOUT) + return 0; + if (dev->announce_all_caps) + return (type == V4L2_BUF_TYPE_VIDEO_CAPTURE || + type == V4L2_BUF_TYPE_VIDEO_OUTPUT) ? + 0 : + -EINVAL; + /* CAPTURE if opener has a capture format or a writer is streaming; + * else OUTPUT. */ + switch (type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if (!(has_capture_token(opener->format_token) || + !has_output_token(dev->stream_tokens))) + return -EINVAL; + break; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (!(has_output_token(opener->format_token) || + has_output_token(dev->stream_tokens))) + return -EINVAL; + break; + default: + return -EINVAL; + } + return 0; +} +/* returns frameinterval (fps) for the set resolution + * called on VIDIOC_ENUM_FRAMEINTERVALS + */ +static int vidioc_enum_frameintervals(struct file *file, void *fh, + struct v4l2_frmivalenum *argp) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + + /* there can be only one... */ + if (argp->index) + return -EINVAL; + + if (dev->keep_format || has_other_owners(opener, dev)) { + /* keep_format also locks the frame rate */ + if (argp->width != dev->pix_format.width || + argp->height != dev->pix_format.height || + argp->pixel_format != dev->pix_format.pixelformat) + return -EINVAL; + + argp->type = V4L2_FRMIVAL_TYPE_DISCRETE; + argp->discrete = dev->capture_param.timeperframe; + } else { + if (argp->width < dev->min_width || + argp->width > dev->max_width || + argp->height < dev->min_height || + argp->height > dev->max_height || + !format_by_fourcc(argp->pixel_format)) + return -EINVAL; + + argp->type = V4L2_FRMIVAL_TYPE_CONTINUOUS; + argp->stepwise.min.numerator = 1; + argp->stepwise.min.denominator = V4L2LOOPBACK_FPS_MAX; + argp->stepwise.max.numerator = V4L2LOOPBACK_FRAME_INTERVAL_MAX; + argp->stepwise.max.denominator = 1; + argp->stepwise.step.numerator = 1; + argp->stepwise.step.denominator = 1; + } + + return 0; +} + +/* Enumerate device formats + * Returns: + * - EINVAL the index is out of bounds; or if non-zero when format is fixed + * - EFAULT unexpected null pointer */ +static int vidioc_enum_fmt_vid(struct file *file, void *fh, + struct v4l2_fmtdesc *f) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + int fixed = dev->keep_format || has_other_owners(opener, dev); + const struct v4l2l_format *fmt; + + if (check_buffer_capability(dev, opener, f->type) < 0) + return -EINVAL; + + if (!(f->index < FORMATS)) + return -EINVAL; + /* TODO: Support 6.14 V4L2_FMTDESC_FLAG_ENUM_ALL */ + if (fixed && f->index) + return -EINVAL; + + fmt = fixed ? format_by_fourcc(dev->pix_format.pixelformat) : + &formats[f->index]; + if (!fmt) + return -EFAULT; + + f->flags = 0; + if (fmt->flags & FORMAT_FLAGS_COMPRESSED) + f->flags |= V4L2_FMT_FLAG_COMPRESSED; + snprintf(f->description, sizeof(f->description), fmt->name); + f->pixelformat = fmt->fourcc; + return 0; +} + +/* Tests (or tries) the format. + * Returns: + * - EINVAL if the buffer type or format is not supported + */ +static int vidioc_try_fmt_vid(struct file *file, void *fh, + struct v4l2_format *f) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + + if (check_buffer_capability(dev, opener, f->type) < 0) + return -EINVAL; + if (v4l2l_fill_format(f, dev->min_width, dev->max_width, + dev->min_height, dev->max_height) != 0) + return -EINVAL; + if (dev->keep_format || has_other_owners(opener, dev)) + /* use existing format - including colorspace info */ + f->fmt.pix = dev->pix_format; + + return 0; +} + +/* Sets new format. Fills 'f' argument with the requested or existing format. + * Side-effect: buffers are allocated for the (returned) format. + * Returns: + * - EINVAL if the type is not supported + * - EBUSY if buffers are already allocated + * TODO: (vasaka) set subregions of input + */ +static int vidioc_s_fmt_vid(struct file *file, void *fh, struct v4l2_format *f) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 token = opener->io_method == V4L2L_IO_TIMEOUT ? + V4L2L_TOKEN_TIMEOUT : + token_from_type(f->type); + int changed, result; + char buf[5]; + + result = vidioc_try_fmt_vid(file, fh, f); + if (result < 0) + return result; + + if (opener->buffer_count > 0) + /* must free buffers before format can be set */ + return -EBUSY; + + result = mutex_lock_killable(&dev->image_mutex); + if (result < 0) + return result; + + if (opener->format_token) + release_token(dev, opener, format); + if (!(dev->format_tokens & token)) { + result = -EBUSY; + goto exit_s_fmt_unlock; + } + + dprintk("S_FMT[%s] %4s:%ux%u size=%u\n", + V4L2_TYPE_IS_CAPTURE(f->type) ? "CAPTURE" : "OUTPUT", + fourcc2str(f->fmt.pix.pixelformat, buf), f->fmt.pix.width, + f->fmt.pix.height, f->fmt.pix.sizeimage); + changed = !pix_format_eq(&dev->pix_format, &f->fmt.pix, 0); + if (changed || has_no_owners(dev)) { + result = allocate_buffers(dev, &f->fmt.pix); + if (result < 0) + goto exit_s_fmt_unlock; + } + if ((dev->timeout_image && changed) || + (!dev->timeout_image && need_timeout_buffer(dev, token))) { + result = allocate_timeout_buffer(dev); + if (result < 0) + goto exit_s_fmt_free; + } + if (changed) { + dev->pix_format = f->fmt.pix; + dev->pix_format_has_valid_sizeimage = + v4l2l_pix_format_has_valid_sizeimage(f); + } + acquire_token(dev, opener, format, token); + if (opener->io_method == V4L2L_IO_TIMEOUT) + dev->timeout_image_io = 0; + goto exit_s_fmt_unlock; +exit_s_fmt_free: + free_buffers(dev); +exit_s_fmt_unlock: + mutex_unlock(&dev->image_mutex); + return result; +} + +/* ------------------ CAPTURE ----------------------- */ +/* ioctl for VIDIOC_ENUM_FMT, _G_FMT, _S_FMT, and _TRY_FMT when buffer type + * is V4L2_BUF_TYPE_VIDEO_CAPTURE */ + +static int vidioc_enum_fmt_cap(struct file *file, void *fh, + struct v4l2_fmtdesc *f) +{ + return vidioc_enum_fmt_vid(file, fh, f); +} + +static int vidioc_g_fmt_cap(struct file *file, void *fh, struct v4l2_format *f) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, f->type) < 0) + return -EINVAL; + f->fmt.pix = dev->pix_format; + return 0; +} + +static int vidioc_try_fmt_cap(struct file *file, void *fh, + struct v4l2_format *f) +{ + return vidioc_try_fmt_vid(file, fh, f); +} + +static int vidioc_s_fmt_cap(struct file *file, void *fh, struct v4l2_format *f) +{ + return vidioc_s_fmt_vid(file, fh, f); +} + +/* ------------------ OUTPUT ----------------------- */ +/* ioctl for VIDIOC_ENUM_FMT, _G_FMT, _S_FMT, and _TRY_FMT when buffer type + * is V4L2_BUF_TYPE_VIDEO_OUTPUT */ + +static int vidioc_enum_fmt_out(struct file *file, void *fh, + struct v4l2_fmtdesc *f) +{ + return vidioc_enum_fmt_vid(file, fh, f); +} + +static int vidioc_g_fmt_out(struct file *file, void *fh, struct v4l2_format *f) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, f->type) < 0) + return -EINVAL; + /* + * LATER: this should return the currently valid format + * gstreamer doesn't like it, if this returns -EINVAL, as it + * then concludes that there is _no_ valid format + * CHECK whether this assumption is wrong, + * or whether we have to always provide a valid format + */ + f->fmt.pix = dev->pix_format; + return 0; +} + +static int vidioc_try_fmt_out(struct file *file, void *fh, + struct v4l2_format *f) +{ + return vidioc_try_fmt_vid(file, fh, f); +} + +static int vidioc_s_fmt_out(struct file *file, void *fh, struct v4l2_format *f) +{ + return vidioc_s_fmt_vid(file, fh, f); +} + +// #define V4L2L_OVERLAY +#ifdef V4L2L_OVERLAY +/* ------------------ OVERLAY ----------------------- */ +/* currently unsupported */ +/* GSTreamer's v4l2sink is buggy, as it requires the overlay to work + * while it should only require it, if overlay is requested + * once the gstreamer element is fixed, remove the overlay dummies + */ +#warning OVERLAY dummies +static int vidioc_g_fmt_overlay(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + return 0; +} + +static int vidioc_s_fmt_overlay(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + return 0; +} +#endif /* V4L2L_OVERLAY */ + +/* ------------------ PARAMs ----------------------- */ + +/* get some data flow parameters, only capability, fps and readbuffers has + * effect on this driver + * called on VIDIOC_G_PARM + */ +static int vidioc_g_parm(struct file *file, void *fh, + struct v4l2_streamparm *parm) +{ + /* do not care about type of opener, hope these enums would always be + * compatible */ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, parm->type) < 0) + return -EINVAL; + parm->parm.capture = dev->capture_param; + return 0; +} + +/* get some data flow parameters, only capability, fps and readbuffers has + * effect on this driver + * called on VIDIOC_S_PARM + */ +static int vidioc_s_parm(struct file *file, void *fh, + struct v4l2_streamparm *parm) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + + dprintk("S_PARM(frame-time=%u/%u)\n", + parm->parm.capture.timeperframe.numerator, + parm->parm.capture.timeperframe.denominator); + if (check_buffer_capability(dev, opener, parm->type) < 0) + return -EINVAL; + + switch (parm->type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + set_timeperframe(dev, &parm->parm.capture.timeperframe); + break; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + set_timeperframe(dev, &parm->parm.output.timeperframe); + break; + default: + return -EINVAL; + } + + parm->parm.capture = dev->capture_param; + return 0; +} + +#ifdef V4L2LOOPBACK_WITH_STD +/* sets a tv standard, actually we do not need to handle this any special way + * added to support effecttv + * called on VIDIOC_S_STD + */ +static int vidioc_s_std(struct file *file, void *fh, v4l2_std_id *_std) +{ + v4l2_std_id req_std = 0, supported_std = 0; + const v4l2_std_id all_std = V4L2_STD_ALL, no_std = 0; + + if (_std) { + req_std = *_std; + *_std = all_std; + } + + /* we support everything in V4L2_STD_ALL, but not more... */ + supported_std = (all_std & req_std); + if (no_std == supported_std) + return -EINVAL; + + return 0; +} + +/* gets a fake video standard + * called on VIDIOC_G_STD + */ +static int vidioc_g_std(struct file *file, void *fh, v4l2_std_id *norm) +{ + if (norm) + *norm = V4L2_STD_ALL; + return 0; +} +/* gets a fake video standard + * called on VIDIOC_QUERYSTD + */ +static int vidioc_querystd(struct file *file, void *fh, v4l2_std_id *norm) +{ + if (norm) + *norm = V4L2_STD_ALL; + return 0; +} +#endif /* V4L2LOOPBACK_WITH_STD */ + +static int v4l2loopback_set_ctrl(struct v4l2_loopback_device *dev, u32 id, + s64 val) +{ + int result = 0; + switch (id) { + case CID_KEEP_FORMAT: + if (val < 0 || val > 1) + return -EINVAL; + dev->keep_format = val; + result = mutex_lock_killable(&dev->image_mutex); + if (result < 0) + return result; + if (!dev->keep_format) { + if (has_no_owners(dev) && !any_buffers_mapped(dev)) + free_buffers(dev); + } + mutex_unlock(&dev->image_mutex); + break; + case CID_SUSTAIN_FRAMERATE: + if (val < 0 || val > 1) + return -EINVAL; + spin_lock_bh(&dev->lock); + dev->sustain_framerate = val; + check_timers(dev); + spin_unlock_bh(&dev->lock); + break; + case CID_TIMEOUT: + if (val < 0 || val > MAX_TIMEOUT) + return -EINVAL; + if (val > 0) { + result = mutex_lock_killable(&dev->image_mutex); + if (result < 0) + return result; + /* on-the-fly allocate if device is owned; else + * allocate occurs on next S_FMT or REQBUFS */ + if (!has_no_owners(dev)) + result = allocate_timeout_buffer(dev); + mutex_unlock(&dev->image_mutex); + if (result < 0) { + /* disable timeout as buffer not alloc'd */ + spin_lock_bh(&dev->lock); + dev->timeout_jiffies = 0; + spin_unlock_bh(&dev->lock); + return result; + } + } + spin_lock_bh(&dev->lock); + dev->timeout_jiffies = msecs_to_jiffies(val); + check_timers(dev); + spin_unlock_bh(&dev->lock); + break; + case CID_TIMEOUT_IMAGE_IO: + dev->timeout_image_io = 1; + break; + default: + return -EINVAL; + } + return 0; +} + +static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl) +{ + struct v4l2_loopback_device *dev = container_of( + ctrl->handler, struct v4l2_loopback_device, ctrl_handler); + return v4l2loopback_set_ctrl(dev, ctrl->id, ctrl->val); +} + +/* returns set of device outputs, in our case there is only one + * called on VIDIOC_ENUMOUTPUT + */ +static int vidioc_enum_output(struct file *file, void *fh, + struct v4l2_output *outp) +{ + __u32 index = outp->index; + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_OUTPUT)) + return -ENOTTY; + if (index) + return -EINVAL; + + /* clear all data (including the reserved fields) */ + memset(outp, 0, sizeof(*outp)); + + outp->index = index; + strscpy(outp->name, "loopback in", sizeof(outp->name)); + outp->type = V4L2_OUTPUT_TYPE_ANALOG; + outp->audioset = 0; + outp->modulator = 0; +#ifdef V4L2LOOPBACK_WITH_STD + outp->std = V4L2_STD_ALL; +#ifdef V4L2_OUT_CAP_STD + outp->capabilities |= V4L2_OUT_CAP_STD; +#endif /* V4L2_OUT_CAP_STD */ +#endif /* V4L2LOOPBACK_WITH_STD */ + + return 0; +} + +/* which output is currently active, + * called on VIDIOC_G_OUTPUT + */ +static int vidioc_g_output(struct file *file, void *fh, unsigned int *index) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_OUTPUT)) + return -ENOTTY; + if (index) + *index = 0; + return 0; +} + +/* set output, can make sense if we have more than one video src, + * called on VIDIOC_S_OUTPUT + */ +static int vidioc_s_output(struct file *file, void *fh, unsigned int index) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_OUTPUT)) + return -ENOTTY; + return index == 0 ? index : -EINVAL; +} + +/* returns set of device inputs, in our case there is only one, + * but later I may add more + * called on VIDIOC_ENUMINPUT + */ +static int vidioc_enum_input(struct file *file, void *fh, + struct v4l2_input *inp) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + __u32 index = inp->index; + + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_CAPTURE)) + return -ENOTTY; + if (index) + return -EINVAL; + + /* clear all data (including the reserved fields) */ + memset(inp, 0, sizeof(*inp)); + + inp->index = index; + strscpy(inp->name, "loopback", sizeof(inp->name)); + inp->type = V4L2_INPUT_TYPE_CAMERA; + inp->audioset = 0; + inp->tuner = 0; + inp->status = 0; + +#ifdef V4L2LOOPBACK_WITH_STD + inp->std = V4L2_STD_ALL; +#ifdef V4L2_IN_CAP_STD + inp->capabilities |= V4L2_IN_CAP_STD; +#endif +#endif /* V4L2LOOPBACK_WITH_STD */ + + if (has_output_token(dev->stream_tokens) && !dev->keep_format) + /* if no outputs attached; pretend device is powered off */ + inp->status |= V4L2_IN_ST_NO_SIGNAL; + + return 0; +} + +/* which input is currently active, + * called on VIDIOC_G_INPUT + */ +static int vidioc_g_input(struct file *file, void *fh, unsigned int *index) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_CAPTURE)) + return -ENOTTY; /* NOTE: -EAGAIN might be more informative */ + if (index) + *index = 0; + return 0; +} + +/* set input, can make sense if we have more than one video src, + * called on VIDIOC_S_INPUT + */ +static int vidioc_s_input(struct file *file, void *fh, unsigned int index) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + if (index != 0) + return -EINVAL; + if (check_buffer_capability(dev, opener, V4L2_BUF_TYPE_VIDEO_CAPTURE)) + return -ENOTTY; /* NOTE: -EAGAIN might be more informative */ + return 0; +} + +/* --------------- V4L2 ioctl buffer related calls ----------------- */ + +#define is_allocated(opener, type, index) \ + (opener->format_token & (opener->io_method == V4L2L_IO_TIMEOUT ? \ + V4L2L_TOKEN_TIMEOUT : \ + token_from_type(type)) && \ + (index) < (opener)->buffer_count) +#define BUFFER_DEBUG_FMT_STR \ + "buffer#%u @ %p type=%u bytesused=%u length=%u flags=%x " \ + "field=%u timestamp= %lld.%06lldsequence=%u\n" +#define BUFFER_DEBUG_FMT_ARGS(buf) \ + (buf)->index, (buf), (buf)->type, (buf)->bytesused, (buf)->length, \ + (buf)->flags, (buf)->field, \ + (long long)(buf)->timestamp.tv_sec, \ + (long long)(buf)->timestamp.tv_usec, (buf)->sequence +/* Buffer flag helpers */ +#define unset_flags(flags) \ + do { \ + flags &= ~V4L2_BUF_FLAG_QUEUED; \ + flags &= ~V4L2_BUF_FLAG_DONE; \ + } while (0) +#define set_queued(flags) \ + do { \ + flags |= V4L2_BUF_FLAG_QUEUED; \ + flags &= ~V4L2_BUF_FLAG_DONE; \ + } while (0) +#define set_done(flags) \ + do { \ + flags &= ~V4L2_BUF_FLAG_QUEUED; \ + flags |= V4L2_BUF_FLAG_DONE; \ + } while (0) + +static bool any_buffers_mapped(struct v4l2_loopback_device *dev) +{ + u32 index; + for (index = 0; index < dev->buffer_count; ++index) + if (dev->buffers[index].buffer.flags & V4L2_BUF_FLAG_MAPPED) + return true; + return false; +} + +static void prepare_buffer_queue(struct v4l2_loopback_device *dev, int count) +{ + struct v4l2l_buffer *bufd, *n; + u32 pos; + + spin_lock_bh(&dev->list_lock); + + /* ensure sufficient number of buffers in queue */ + for (pos = 0; pos < count; ++pos) { + bufd = &dev->buffers[pos]; + if (list_empty(&bufd->list_head)) + list_add_tail(&bufd->list_head, &dev->outbufs_list); + } + if (list_empty(&dev->outbufs_list)) + goto exit_prepare_queue_unlock; + + /* remove any excess buffers */ + list_for_each_entry_safe(bufd, n, &dev->outbufs_list, list_head) { + if (bufd->buffer.index >= count) + list_del_init(&bufd->list_head); + } + + /* buffers are no longer queued; and `write_position` will correspond + * to the first item of `outbufs_list`. */ + pos = v4l2l_mod64(dev->write_position, count); + list_for_each_entry(bufd, &dev->outbufs_list, list_head) { + unset_flags(bufd->buffer.flags); + dev->bufpos2index[pos % count] = bufd->buffer.index; + ++pos; + } +exit_prepare_queue_unlock: + spin_unlock_bh(&dev->list_lock); +} + +/* forward declaration */ +static int vidioc_streamoff(struct file *file, void *fh, + enum v4l2_buf_type type); +/* negotiate buffer type + * only mmap streaming supported + * called on VIDIOC_REQBUFS + */ +static int vidioc_reqbufs(struct file *file, void *fh, + struct v4l2_requestbuffers *reqbuf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 token = opener->io_method == V4L2L_IO_TIMEOUT ? + V4L2L_TOKEN_TIMEOUT : + token_from_type(reqbuf->type); + u32 req_count = reqbuf->count; + int result = 0; + + dprintk("REQBUFS(memory=%u, req_count=%u) and device-bufs=%u/%u " + "[used/max]\n", + reqbuf->memory, req_count, dev->used_buffer_count, + dev->buffer_count); + + switch (reqbuf->memory) { + case V4L2_MEMORY_MMAP: +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 20, 0) + reqbuf->capabilities = 0; /* only guarantee MMAP support */ +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 16, 0) + reqbuf->flags = 0; /* no memory consistency support */ +#endif + break; + default: + return -EINVAL; + } + + if (opener->format_token & ~token) + /* different (buffer) type already assigned to descriptor by + * S_FMT or REQBUFS */ + return -EINVAL; + + MARK(); + result = mutex_lock_killable(&dev->image_mutex); + if (result < 0) + return result; /* -EINTR */ + + /* CASE queue/dequeue timeout-buffer only: */ + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) { + opener->buffer_count = req_count; + if (req_count == 0) + release_token(dev, opener, format); + goto exit_reqbufs_unlock; + } + + MARK(); + /* CASE count is zero: streamoff, free buffers, release their token */ + if (req_count == 0) { + if (dev->format_tokens & token) { + acquire_token(dev, opener, format, token); + opener->io_method = V4L2L_IO_MMAP; + } + result = vidioc_streamoff(file, fh, reqbuf->type); + opener->buffer_count = 0; + /* undocumented requirement - REQBUFS with count zero should + * ALSO release lock on logical stream */ + if (opener->format_token) + release_token(dev, opener, format); + if (has_no_owners(dev)) + dev->used_buffer_count = 0; + goto exit_reqbufs_unlock; + } + + /* CASE count non-zero: allocate buffers and acquire token for them */ + MARK(); + switch (reqbuf->type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (!(dev->format_tokens & token || + opener->format_token & token)) + /* only exclusive ownership for each stream */ + result = -EBUSY; + break; + default: + result = -EINVAL; + } + if (result < 0) + goto exit_reqbufs_unlock; + + if (has_other_owners(opener, dev) && dev->used_buffer_count > 0) { + /* allow 'allocation' of existing number of buffers */ + req_count = dev->used_buffer_count; + } else if (any_buffers_mapped(dev)) { + /* do not allow re-allocation if buffers are mapped */ + result = -EBUSY; + goto exit_reqbufs_unlock; + } + + MARK(); + opener->buffer_count = 0; + + if (req_count > dev->buffer_count) + req_count = dev->buffer_count; + + if (has_no_owners(dev)) { + result = allocate_buffers(dev, &dev->pix_format); + if (result < 0) + goto exit_reqbufs_unlock; + } + if (!dev->timeout_image && need_timeout_buffer(dev, token)) { + result = allocate_timeout_buffer(dev); + if (result < 0) + goto exit_reqbufs_unlock; + } + acquire_token(dev, opener, format, token); + + MARK(); + switch (opener->io_method) { + case V4L2L_IO_TIMEOUT: + dev->timeout_image_io = 0; + opener->buffer_count = req_count; + break; + default: + opener->io_method = V4L2L_IO_MMAP; + prepare_buffer_queue(dev, req_count); + dev->used_buffer_count = opener->buffer_count = req_count; + } +exit_reqbufs_unlock: + mutex_unlock(&dev->image_mutex); + reqbuf->count = opener->buffer_count; + return result; +} + +/* returns buffer asked for; + * give app as many buffers as it wants, if it less than MAX, + * but map them in our inner buffers + * called on VIDIOC_QUERYBUF + */ +static int vidioc_querybuf(struct file *file, void *fh, struct v4l2_buffer *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 type = buf->type; + u32 index = buf->index; + + if ((type != V4L2_BUF_TYPE_VIDEO_CAPTURE) && + (type != V4L2_BUF_TYPE_VIDEO_OUTPUT)) + return -EINVAL; + if (!is_allocated(opener, type, index)) + return -EINVAL; + + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) { + *buf = dev->timeout_buffer.buffer; + buf->index = index; + } else + *buf = dev->buffers[index].buffer; + + buf->type = type; + + if (!(buf->flags & (V4L2_BUF_FLAG_DONE | V4L2_BUF_FLAG_QUEUED))) { + /* v4l2-compliance requires these to be zero */ + buf->sequence = 0; + buf->timestamp.tv_sec = buf->timestamp.tv_usec = 0; + } else if (V4L2_TYPE_IS_CAPTURE(type)) { + /* guess flags based on sequence values */ + if (buf->sequence >= opener->read_position) { + set_done(buf->flags); + } else if (buf->flags & V4L2_BUF_FLAG_DONE) { + set_queued(buf->flags); + } + } + dprintkrw("QUERYBUF(%s, index=%u) -> " BUFFER_DEBUG_FMT_STR, + V4L2_TYPE_IS_CAPTURE(type) ? "CAPTURE" : "OUTPUT", index, + BUFFER_DEBUG_FMT_ARGS(buf)); + return 0; +} + +static void buffer_written(struct v4l2_loopback_device *dev, + struct v4l2l_buffer *buf) +{ + del_timer_sync(&dev->sustain_timer); + del_timer_sync(&dev->timeout_timer); + + spin_lock_bh(&dev->list_lock); + list_move_tail(&buf->list_head, &dev->outbufs_list); + spin_unlock_bh(&dev->list_lock); + + spin_lock_bh(&dev->lock); + dev->bufpos2index[v4l2l_mod64(dev->write_position, + dev->used_buffer_count)] = + buf->buffer.index; + ++dev->write_position; + dev->reread_count = 0; + + check_timers(dev); + spin_unlock_bh(&dev->lock); +} + +/* put buffer to queue + * called on VIDIOC_QBUF + */ +static int vidioc_qbuf(struct file *file, void *fh, struct v4l2_buffer *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + struct v4l2l_buffer *bufd; + u32 index = buf->index; + u32 type = buf->type; + + if (!is_allocated(opener, type, index)) + return -EINVAL; + bufd = &dev->buffers[index]; + + switch (buf->memory) { + case V4L2_MEMORY_MMAP: + if (!(bufd->buffer.flags & V4L2_BUF_FLAG_MAPPED)) + dprintkrw("QBUF() unmapped buffer [index=%u]\n", index); + break; + default: + return -EINVAL; + } + + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) { + set_queued(buf->flags); + return 0; + } + + switch (type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + dprintkrw("QBUF(CAPTURE, index=%u) -> " BUFFER_DEBUG_FMT_STR, + index, BUFFER_DEBUG_FMT_ARGS(buf)); + set_queued(buf->flags); + break; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + dprintkrw("QBUF(OUTPUT, index=%u) -> " BUFFER_DEBUG_FMT_STR, + index, BUFFER_DEBUG_FMT_ARGS(buf)); + if (!(bufd->buffer.flags & V4L2_BUF_FLAG_TIMESTAMP_COPY) && + (buf->timestamp.tv_sec == 0 && + buf->timestamp.tv_usec == 0)) { + v4l2l_get_timestamp(&bufd->buffer); + } else { + bufd->buffer.timestamp = buf->timestamp; + bufd->buffer.flags |= V4L2_BUF_FLAG_TIMESTAMP_COPY; + bufd->buffer.flags &= + ~V4L2_BUF_FLAG_TIMESTAMP_MONOTONIC; + } + if (dev->pix_format_has_valid_sizeimage) { + if (buf->bytesused >= dev->pix_format.sizeimage) { + bufd->buffer.bytesused = + dev->pix_format.sizeimage; + } else { +#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 5, 0) + dev_warn_ratelimited( + &dev->vdev->dev, +#else + dprintkrw( +#endif + "warning queued output buffer bytesused too small %u < %u\n", + buf->bytesused, + dev->pix_format.sizeimage); + bufd->buffer.bytesused = buf->bytesused; + } + } else { + bufd->buffer.bytesused = buf->bytesused; + } + bufd->buffer.sequence = dev->write_position; + set_queued(bufd->buffer.flags); + *buf = bufd->buffer; + buffer_written(dev, bufd); + set_done(bufd->buffer.flags); + wake_up_all(&dev->read_event); + break; + default: + return -EINVAL; + } + buf->type = type; + return 0; +} + +static int can_read(struct v4l2_loopback_device *dev, + struct v4l2_loopback_opener *opener) +{ + int ret; + + spin_lock_bh(&dev->lock); + check_timers(dev); + ret = dev->write_position > opener->read_position || + dev->reread_count > opener->reread_count || dev->timeout_happened; + spin_unlock_bh(&dev->lock); + return ret; +} + +static int get_capture_buffer(struct file *file) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); + int pos, timeout_happened; + u32 index; + + if ((file->f_flags & O_NONBLOCK) && + (dev->write_position <= opener->read_position && + dev->reread_count <= opener->reread_count && + !dev->timeout_happened)) + return -EAGAIN; + wait_event_interruptible(dev->read_event, can_read(dev, opener)); + + spin_lock_bh(&dev->lock); + if (dev->write_position == opener->read_position) { + if (dev->reread_count > opener->reread_count + 2) + opener->reread_count = dev->reread_count - 1; + ++opener->reread_count; + pos = v4l2l_mod64(opener->read_position + + dev->used_buffer_count - 1, + dev->used_buffer_count); + } else { + opener->reread_count = 0; + if (dev->write_position > + opener->read_position + dev->used_buffer_count) + opener->read_position = dev->write_position - 1; + pos = v4l2l_mod64(opener->read_position, + dev->used_buffer_count); + ++opener->read_position; + } + timeout_happened = dev->timeout_happened && (dev->timeout_jiffies > 0); + dev->timeout_happened = 0; + spin_unlock_bh(&dev->lock); + + index = dev->bufpos2index[pos]; + if (timeout_happened) { + if (index >= dev->used_buffer_count) { + dprintkrw("get_capture_buffer() read position is at " + "an unallocated buffer [index=%u]\n", + index); + return -EFAULT; + } + /* although allocated on-demand, timeout_image is freed only + * in free_buffers(), so we don't need to worry about it being + * deallocated suddenly */ + memcpy(dev->image + dev->buffers[index].buffer.m.offset, + dev->timeout_image, dev->buffer_size); + } + return (int)index; +} + +/* put buffer to dequeue + * called on VIDIOC_DQBUF + */ +static int vidioc_dqbuf(struct file *file, void *fh, struct v4l2_buffer *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 type = buf->type; + int index; + struct v4l2l_buffer *bufd; + + if (buf->memory != V4L2_MEMORY_MMAP) + return -EINVAL; + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) { + *buf = dev->timeout_buffer.buffer; + buf->type = type; + unset_flags(buf->flags); + return 0; + } + if ((opener->buffer_count == 0) || + !(opener->format_token & token_from_type(type))) + return -EINVAL; + + switch (type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + index = get_capture_buffer(file); + if (index < 0) + return index; + *buf = dev->buffers[index].buffer; + unset_flags(buf->flags); + break; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + spin_lock_bh(&dev->list_lock); + + bufd = list_first_entry_or_null(&dev->outbufs_list, + struct v4l2l_buffer, list_head); + if (bufd) + list_move_tail(&bufd->list_head, &dev->outbufs_list); + + spin_unlock_bh(&dev->list_lock); + if (!bufd) + return -EFAULT; + unset_flags(bufd->buffer.flags); + *buf = bufd->buffer; + break; + default: + return -EINVAL; + } + + buf->type = type; + dprintkrw("DQBUF(%s, index=%u) -> " BUFFER_DEBUG_FMT_STR, + V4L2_TYPE_IS_CAPTURE(type) ? "CAPTURE" : "OUTPUT", index, + BUFFER_DEBUG_FMT_ARGS(buf)); + return 0; +} + +/* ------------- STREAMING ------------------- */ + +/* start streaming + * called on VIDIOC_STREAMON + */ +static int vidioc_streamon(struct file *file, void *fh, enum v4l2_buf_type type) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 token = token_from_type(type); + + /* short-circuit when using timeout buffer set */ + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) + return 0; + /* opener must have claimed (same) buffer set via REQBUFS */ + if (!opener->buffer_count || !(opener->format_token & token)) + return -EINVAL; + + switch (type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if (has_output_token(dev->stream_tokens) && !dev->keep_format) + return -EIO; + if (dev->stream_tokens & token) { + acquire_token(dev, opener, stream, token); + client_usage_queue_event(dev->vdev); + } + return 0; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (dev->stream_tokens & token) + acquire_token(dev, opener, stream, token); + return 0; + default: + return -EINVAL; + } +} + +/* stop streaming + * called on VIDIOC_STREAMOFF + */ +static int vidioc_streamoff(struct file *file, void *fh, + enum v4l2_buf_type type) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + u32 token = token_from_type(type); + + /* short-circuit when using timeout buffer set */ + if (opener->format_token & V4L2L_TOKEN_TIMEOUT) + return 0; + /* short-circuit when buffer set has no owner */ + if (dev->format_tokens & token) + return 0; + /* opener needs a claim to buffer set */ + if (!opener->format_token) + return -EBUSY; + if (opener->format_token & ~token) + return -EINVAL; + + switch (type) { + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (opener->stream_token & token) + release_token(dev, opener, stream); + /* reset output queue */ + if (dev->used_buffer_count > 0) + prepare_buffer_queue(dev, dev->used_buffer_count); + return 0; + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if (opener->stream_token & token) { + release_token(dev, opener, stream); + client_usage_queue_event(dev->vdev); + } + return 0; + default: + return -EINVAL; + } +} + +#ifdef CONFIG_VIDEO_V4L1_COMPAT +static int vidiocgmbuf(struct file *file, void *fh, struct video_mbuf *p) +{ + struct v4l2_loopback_device *dev; + MARK(); + + dev = v4l2loopback_getdevice(file); + p->frames = dev->buffer_count; + p->offsets[0] = 0; + p->offsets[1] = 0; + p->size = dev->buffer_size; + return 0; +} +#endif + +static void client_usage_queue_event(struct video_device *vdev) +{ + struct v4l2_event ev; + struct v4l2_loopback_device *dev; + + dev = container_of(vdev->v4l2_dev, struct v4l2_loopback_device, + v4l2_dev); + + memset(&ev, 0, sizeof(ev)); + ev.type = V4L2_EVENT_PRI_CLIENT_USAGE; + ((struct v4l2_event_client_usage *)&ev.u)->count = + !has_capture_token(dev->stream_tokens); + + v4l2_event_queue(vdev, &ev); +} + +static int client_usage_ops_add(struct v4l2_subscribed_event *sev, + unsigned elems) +{ + if (!(sev->flags & V4L2_EVENT_SUB_FL_SEND_INITIAL)) + return 0; + + client_usage_queue_event(sev->fh->vdev); + return 0; +} + +static void client_usage_ops_replace(struct v4l2_event *old, + const struct v4l2_event *new) +{ + *((struct v4l2_event_client_usage *)&old->u) = + *((struct v4l2_event_client_usage *)&new->u); +} + +static void client_usage_ops_merge(const struct v4l2_event *old, + struct v4l2_event *new) +{ + *((struct v4l2_event_client_usage *)&new->u) = + *((struct v4l2_event_client_usage *)&old->u); +} + +const struct v4l2_subscribed_event_ops client_usage_ops = { + .add = client_usage_ops_add, + .replace = client_usage_ops_replace, + .merge = client_usage_ops_merge, +}; + +static int vidioc_subscribe_event(struct v4l2_fh *fh, + const struct v4l2_event_subscription *sub) +{ + switch (sub->type) { + case V4L2_EVENT_CTRL: + return v4l2_ctrl_subscribe_event(fh, sub); + case V4L2_EVENT_PRI_CLIENT_USAGE: + return v4l2_event_subscribe(fh, sub, 0, &client_usage_ops); + } + + return -EINVAL; +} + +/* file operations */ +static void vm_open(struct vm_area_struct *vma) +{ + struct v4l2l_buffer *buf; + MARK(); + + buf = vma->vm_private_data; + atomic_inc(&buf->use_count); + buf->buffer.flags |= V4L2_BUF_FLAG_MAPPED; +} + +static void vm_close(struct vm_area_struct *vma) +{ + struct v4l2l_buffer *buf; + MARK(); + + buf = vma->vm_private_data; + if (atomic_dec_and_test(&buf->use_count)) + buf->buffer.flags &= ~V4L2_BUF_FLAG_MAPPED; +} + +static struct vm_operations_struct vm_ops = { + .open = vm_open, + .close = vm_close, +}; + +static int v4l2_loopback_mmap(struct file *file, struct vm_area_struct *vma) +{ + u8 *addr; + unsigned long start, size, offset; + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); + struct v4l2l_buffer *buffer = NULL; + int result = 0; + MARK(); + + offset = (unsigned long)vma->vm_pgoff << PAGE_SHIFT; + start = (unsigned long)vma->vm_start; + size = (unsigned long)(vma->vm_end - vma->vm_start); /* always != 0 */ + + /* ensure buffer size, count, and allocated image(s) are not altered by + * other file descriptors */ + result = mutex_lock_killable(&dev->image_mutex); + if (result < 0) + return result; + + if (size > dev->buffer_size) { + dprintk("mmap() attempt to map %lubytes when %ubytes are " + "allocated to buffers\n", + size, dev->buffer_size); + result = -EINVAL; + goto exit_mmap_unlock; + } + if (offset % dev->buffer_size != 0) { + dprintk("mmap() offset does not match start of any buffer\n"); + result = -EINVAL; + goto exit_mmap_unlock; + } + switch (opener->format_token) { + case V4L2L_TOKEN_TIMEOUT: + if (offset != (unsigned long)dev->buffer_size * MAX_BUFFERS) { + dprintk("mmap() incorrect offset for timeout image\n"); + result = -EINVAL; + goto exit_mmap_unlock; + } + buffer = &dev->timeout_buffer; + addr = dev->timeout_image; + break; + default: + if (offset >= dev->image_size) { + dprintk("mmap() attempt to map beyond all buffers\n"); + result = -EINVAL; + goto exit_mmap_unlock; + } + u32 index = offset / dev->buffer_size; + buffer = &dev->buffers[index]; + addr = dev->image + offset; + break; + } + + while (size > 0) { + struct page *page = vmalloc_to_page(addr); + + result = vm_insert_page(vma, start, page); + if (result < 0) + goto exit_mmap_unlock; + + start += PAGE_SIZE; + addr += PAGE_SIZE; + size -= PAGE_SIZE; + } + + vma->vm_ops = &vm_ops; + vma->vm_private_data = buffer; + + vm_open(vma); +exit_mmap_unlock: + mutex_unlock(&dev->image_mutex); + return result; +} + +static unsigned int v4l2_loopback_poll(struct file *file, + struct poll_table_struct *pts) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); + __poll_t req_events = poll_requested_events(pts); + int ret_mask = 0; + + /* call poll_wait in first call, regardless, to ensure that the + * wait-queue is not null */ + poll_wait(file, &dev->read_event, pts); + poll_wait(file, &opener->fh.wait, pts); + + if (req_events & POLLPRI) { + if (v4l2_event_pending(&opener->fh)) { + ret_mask |= POLLPRI; + if (!(req_events & DEFAULT_POLLMASK)) + return ret_mask; + } + } + + switch (opener->format_token) { + case V4L2L_TOKEN_OUTPUT: + if (opener->stream_token != 0 || + opener->io_method == V4L2L_IO_NONE) + ret_mask |= POLLOUT | POLLWRNORM; + break; + case V4L2L_TOKEN_CAPTURE: + if ((opener->io_method == V4L2L_IO_NONE || + opener->stream_token != 0) && + can_read(dev, opener)) + ret_mask |= POLLIN | POLLWRNORM; + break; + case V4L2L_TOKEN_TIMEOUT: + ret_mask |= POLLOUT | POLLWRNORM; + break; + default: + break; + } + + return ret_mask; +} + +/* do not want to limit device opens, it can be as many readers as user want, + * writers are limited by means of setting writer field */ +static int v4l2_loopback_open(struct file *file) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + + dev = v4l2loopback_getdevice(file); + if (dev->open_count.counter >= dev->max_openers) + return -EBUSY; + /* kfree on close */ + opener = kzalloc(sizeof(*opener), GFP_KERNEL); + if (opener == NULL) + return -ENOMEM; + + atomic_inc(&dev->open_count); + if (dev->timeout_image_io && dev->format_tokens & V4L2L_TOKEN_TIMEOUT) + /* will clear timeout_image_io once buffer set acquired */ + opener->io_method = V4L2L_IO_TIMEOUT; + + v4l2_fh_init(&opener->fh, video_devdata(file)); + file->private_data = &opener->fh; + + v4l2_fh_add(&opener->fh); + dprintk("open() -> dev@%p with image@%p\n", dev, + dev ? dev->image : NULL); + return 0; +} + +static int v4l2_loopback_close(struct file *file) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); + int result = 0; + dprintk("close() -> dev@%p with image@%p\n", dev, + dev ? dev->image : NULL); + + if (opener->format_token) { + struct v4l2_requestbuffers reqbuf = { + .count = 0, .memory = V4L2_MEMORY_MMAP, .type = 0 + }; + switch (opener->format_token) { + case V4L2L_TOKEN_CAPTURE: + reqbuf.type = V4L2_BUF_TYPE_VIDEO_CAPTURE; + break; + case V4L2L_TOKEN_OUTPUT: + case V4L2L_TOKEN_TIMEOUT: + reqbuf.type = V4L2_BUF_TYPE_VIDEO_OUTPUT; + break; + } + if (reqbuf.type) + result = vidioc_reqbufs(file, file->private_data, + &reqbuf); + if (result < 0) + dprintk("failed to free buffers REQBUFS(count=0) " + " returned %d\n", + result); + mutex_lock(&dev->image_mutex); + release_token(dev, opener, format); + mutex_unlock(&dev->image_mutex); + } + + if (atomic_dec_and_test(&dev->open_count)) { + del_timer_sync(&dev->sustain_timer); + del_timer_sync(&dev->timeout_timer); + if (!dev->keep_format) { + mutex_lock(&dev->image_mutex); + free_buffers(dev); + mutex_unlock(&dev->image_mutex); + } + } + + v4l2_fh_del(&opener->fh); + v4l2_fh_exit(&opener->fh); + + kfree(opener); + return 0; +} + +static int start_fileio(struct file *file, void *fh, enum v4l2_buf_type type) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(fh); + struct v4l2_requestbuffers reqbuf = { .count = dev->buffer_count, + .memory = V4L2_MEMORY_MMAP, + .type = type }; + int token = token_from_type(type); + int result; + + if (opener->format_token & V4L2L_TOKEN_TIMEOUT || + opener->format_token & ~token) + return -EBUSY; /* NOTE: -EBADF might be more informative */ + + /* short-circuit if already have stream token */ + if (opener->stream_token && opener->io_method == V4L2L_IO_FILE) + return 0; + + /* otherwise attempt to acquire stream token and assign IO method */ + if (!(dev->stream_tokens & token) || opener->io_method != V4L2L_IO_NONE) + return -EBUSY; + + result = vidioc_reqbufs(file, fh, &reqbuf); + if (result < 0) + return result; + result = vidioc_streamon(file, fh, type); + if (result < 0) + return result; + + opener->io_method = V4L2L_IO_FILE; + return 0; +} + +static ssize_t v4l2_loopback_read(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_buffer *b; + int index, result; + + dprintkrw("read() %zu bytes\n", count); + result = start_fileio(file, file->private_data, + V4L2_BUF_TYPE_VIDEO_CAPTURE); + if (result < 0) + return result; + + index = get_capture_buffer(file); + if (index < 0) + return index; + b = &dev->buffers[index].buffer; + if (count > b->bytesused) + count = b->bytesused; + if (copy_to_user((void *)buf, (void *)(dev->image + b->m.offset), + count)) { + printk(KERN_ERR "v4l2-loopback read() failed copy_to_user()\n"); + return -EFAULT; + } + return count; +} + +static ssize_t v4l2_loopback_write(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_buffer *b; + int index, result; + + dprintkrw("write() %zu bytes\n", count); + result = start_fileio(file, file->private_data, + V4L2_BUF_TYPE_VIDEO_OUTPUT); + if (result < 0) + return result; + + if (count > dev->buffer_size) + count = dev->buffer_size; + index = v4l2l_mod64(dev->write_position, dev->used_buffer_count); + b = &dev->buffers[index].buffer; + + if (copy_from_user((void *)(dev->image + b->m.offset), (void *)buf, + count)) { + printk(KERN_ERR + "v4l2-loopback write() failed copy_from_user()\n"); + return -EFAULT; + } + b->bytesused = count; + + v4l2l_get_timestamp(b); + b->sequence = dev->write_position; + set_queued(b->flags); + buffer_written(dev, &dev->buffers[index]); + set_done(b->flags); + wake_up_all(&dev->read_event); + + return count; +} + +/* init functions */ +/* frees buffers, if allocated */ +static void free_buffers(struct v4l2_loopback_device *dev) +{ + dprintk("free_buffers() with image@%p\n", dev->image); + if (!dev->image) + return; + if (!has_no_owners(dev) || any_buffers_mapped(dev)) + /* maybe an opener snuck in before image_mutex was acquired */ + printk(KERN_WARNING + "v4l2-loopback free_buffers() buffers of video device " + "#%u freed while still mapped to userspace\n", + dev->vdev->num); + vfree(dev->image); + dev->image = NULL; + dev->image_size = 0; + dev->buffer_size = 0; +} + +static void free_timeout_buffer(struct v4l2_loopback_device *dev) +{ + dprintk("free_timeout_buffer() with timeout_image@%p\n", + dev->timeout_image); + if (!dev->timeout_image) + return; + + if ((dev->timeout_jiffies > 0 && !has_no_owners(dev)) || + dev->timeout_buffer.buffer.flags & V4L2_BUF_FLAG_MAPPED) + printk(KERN_WARNING + "v4l2-loopback free_timeout_buffer() timeout image " + "of device #%u freed while still mapped to userspace\n", + dev->vdev->num); + + vfree(dev->timeout_image); + dev->timeout_image = NULL; + dev->timeout_buffer_size = 0; +} +/* allocates buffers if no (other) openers are already using them */ +static int allocate_buffers(struct v4l2_loopback_device *dev, + struct v4l2_pix_format *pix_format) +{ + u32 buffer_size = PAGE_ALIGN(pix_format->sizeimage); + unsigned long image_size = + (unsigned long)buffer_size * (unsigned long)dev->buffer_count; + /* vfree on close file operation in case no open handles left */ + + if (buffer_size == 0 || dev->buffer_count == 0 || + buffer_size < pix_format->sizeimage) + return -EINVAL; + + if ((__LONG_MAX__ / buffer_size) < dev->buffer_count) + return -ENOSPC; + + dprintk("allocate_buffers() size %lubytes = %ubytes x %ubuffers\n", + image_size, buffer_size, dev->buffer_count); + if (dev->image) { + /* check that no buffers are expected in user-space */ + if (!has_no_owners(dev) || any_buffers_mapped(dev)) + return -EBUSY; + dprintk("allocate_buffers() existing size=%lubytes\n", + dev->image_size); + /* FIXME: prevent double allocation more intelligently! */ + if (image_size == dev->image_size) { + dprintk("allocate_buffers() keep existing\n"); + return 0; + } + free_buffers(dev); + } + + /* FIXME: set buffers to 0 */ + dev->image = vmalloc(image_size); + if (dev->image == NULL) { + dev->buffer_size = dev->image_size = 0; + return -ENOMEM; + } + init_buffers(dev, pix_format->sizeimage, buffer_size); + dev->buffer_size = buffer_size; + dev->image_size = image_size; + dprintk("allocate_buffers() -> vmalloc'd %lubytes\n", dev->image_size); + return 0; +} +static int allocate_timeout_buffer(struct v4l2_loopback_device *dev) +{ + /* device's `buffer_size` and `buffers` must be initialised in + * allocate_buffers() */ + + dprintk("allocate_timeout_buffer() size %ubytes\n", dev->buffer_size); + if (dev->buffer_size == 0) + return -EINVAL; + + if (dev->timeout_image) { + if (dev->timeout_buffer.buffer.flags & V4L2_BUF_FLAG_MAPPED) + return -EBUSY; + if (dev->buffer_size == dev->timeout_buffer_size) + return 0; + free_timeout_buffer(dev); + } + + dev->timeout_image = vzalloc(dev->buffer_size); + if (!dev->timeout_image) { + dev->timeout_buffer_size = 0; + return -ENOMEM; + } + dev->timeout_buffer_size = dev->buffer_size; + return 0; +} +/* init inner buffers, they are capture mode and flags are set as for capture + * mode buffers */ +static void init_buffers(struct v4l2_loopback_device *dev, u32 bytes_used, + u32 buffer_size) +{ + u32 i; + + for (i = 0; i < dev->buffer_count; ++i) { + struct v4l2_buffer *b = &dev->buffers[i].buffer; + b->index = i; + b->bytesused = bytes_used; + b->length = buffer_size; + b->field = V4L2_FIELD_NONE; + b->flags = 0; + b->m.offset = i * buffer_size; + b->memory = V4L2_MEMORY_MMAP; + b->sequence = 0; + b->timestamp.tv_sec = 0; + b->timestamp.tv_usec = 0; + b->type = V4L2_BUF_TYPE_VIDEO_CAPTURE; + + v4l2l_get_timestamp(b); + } + dev->timeout_buffer = dev->buffers[0]; + dev->timeout_buffer.buffer.m.offset = MAX_BUFFERS * buffer_size; +} + +/* fills and register video device */ +static void init_vdev(struct video_device *vdev, int nr) +{ +#ifdef V4L2LOOPBACK_WITH_STD + vdev->tvnorms = V4L2_STD_ALL; +#endif /* V4L2LOOPBACK_WITH_STD */ + + vdev->vfl_type = VFL_TYPE_VIDEO; + vdev->fops = &v4l2_loopback_fops; + vdev->ioctl_ops = &v4l2_loopback_ioctl_ops; + vdev->release = &video_device_release; + vdev->minor = -1; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) + vdev->device_caps = V4L2_CAP_DEVICE_CAPS | V4L2_CAP_VIDEO_CAPTURE | + V4L2_CAP_VIDEO_OUTPUT | V4L2_CAP_READWRITE | + V4L2_CAP_STREAMING; +#endif + + if (debug > 1) + vdev->dev_debug = V4L2_DEV_DEBUG_IOCTL | + V4L2_DEV_DEBUG_IOCTL_ARG; + + vdev->vfl_dir = VFL_DIR_M2M; +} + +/* init default capture parameters, only fps may be changed in future */ +static void init_capture_param(struct v4l2_captureparm *capture_param) +{ + capture_param->capability = V4L2_CAP_TIMEPERFRAME; /* since 2.16 */ + capture_param->capturemode = 0; + capture_param->extendedmode = 0; + capture_param->readbuffers = max_buffers; + capture_param->timeperframe.numerator = 1; + capture_param->timeperframe.denominator = V4L2LOOPBACK_FPS_DEFAULT; +} + +static void check_timers(struct v4l2_loopback_device *dev) +{ + if (has_output_token(dev->stream_tokens)) + return; + + if (dev->timeout_jiffies > 0 && !timer_pending(&dev->timeout_timer)) + mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); + if (dev->sustain_framerate && !timer_pending(&dev->sustain_timer)) + mod_timer(&dev->sustain_timer, + jiffies + dev->frame_jiffies * 3 / 2); +} +#ifdef HAVE_TIMER_SETUP +static void sustain_timer_clb(struct timer_list *t) +{ + struct v4l2_loopback_device *dev = from_timer(dev, t, sustain_timer); +#else +static void sustain_timer_clb(unsigned long nr) +{ + struct v4l2_loopback_device *dev = + idr_find(&v4l2loopback_index_idr, nr); +#endif + spin_lock(&dev->lock); + if (dev->sustain_framerate) { + dev->reread_count++; + dprintkrw("sustain_timer_clb() write_pos=%lld reread=%u\n", + (long long)dev->write_position, dev->reread_count); + if (dev->reread_count == 1) + mod_timer(&dev->sustain_timer, + jiffies + max(1UL, dev->frame_jiffies / 2)); + else + mod_timer(&dev->sustain_timer, + jiffies + dev->frame_jiffies); + wake_up_all(&dev->read_event); + } + spin_unlock(&dev->lock); +} +#ifdef HAVE_TIMER_SETUP +static void timeout_timer_clb(struct timer_list *t) +{ + struct v4l2_loopback_device *dev = from_timer(dev, t, timeout_timer); +#else +static void timeout_timer_clb(unsigned long nr) +{ + struct v4l2_loopback_device *dev = + idr_find(&v4l2loopback_index_idr, nr); +#endif + spin_lock(&dev->lock); + if (dev->timeout_jiffies > 0) { + dev->timeout_happened = 1; + mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); + wake_up_all(&dev->read_event); + } + spin_unlock(&dev->lock); +} + +/* init loopback main structure */ +#define DEFAULT_FROM_CONF(confmember, default_condition, default_value) \ + ((conf) ? \ + ((conf->confmember default_condition) ? (default_value) : \ + (conf->confmember)) : \ + default_value) + +static int v4l2_loopback_add(struct v4l2_loopback_config *conf, int *ret_nr) +{ + struct v4l2_loopback_device *dev; + struct v4l2_ctrl_handler *hdl; + struct v4l2loopback_private *vdev_priv = NULL; + int err; + + u32 _width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; + u32 _height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; + + u32 _min_width = DEFAULT_FROM_CONF(min_width, + < V4L2LOOPBACK_SIZE_MIN_WIDTH, + V4L2LOOPBACK_SIZE_MIN_WIDTH); + u32 _min_height = DEFAULT_FROM_CONF(min_height, + < V4L2LOOPBACK_SIZE_MIN_HEIGHT, + V4L2LOOPBACK_SIZE_MIN_HEIGHT); + u32 _max_width = DEFAULT_FROM_CONF(max_width, < _min_width, max_width); + u32 _max_height = + DEFAULT_FROM_CONF(max_height, < _min_height, max_height); + bool _announce_all_caps = (conf && conf->announce_all_caps >= 0) ? + (bool)(conf->announce_all_caps) : + !(V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS); + int _max_buffers = DEFAULT_FROM_CONF(max_buffers, <= 0, max_buffers); + int _max_openers = DEFAULT_FROM_CONF(max_openers, <= 0, max_openers); + struct v4l2_format _fmt; + + int nr = -1; + + if (conf) { + const int output_nr = conf->output_nr; +#ifdef SPLIT_DEVICES + const int capture_nr = conf->capture_nr; +#else + const int capture_nr = output_nr; +#endif + if (capture_nr >= 0 && output_nr == capture_nr) { + nr = output_nr; + } else if (capture_nr < 0 && output_nr < 0) { + nr = -1; + } else if (capture_nr < 0) { + nr = output_nr; + } else if (output_nr < 0) { + nr = capture_nr; + } else { + printk(KERN_ERR + "v4l2-loopback add() split OUTPUT and CAPTURE " + "devices not yet supported.\n"); + printk(KERN_INFO + "v4l2-loopback add() both devices must have the " + "same number (%d != %d).\n", + output_nr, capture_nr); + return -EINVAL; + } + } + + if (idr_find(&v4l2loopback_index_idr, nr)) + return -EEXIST; + + /* initialisation of a new device */ + dprintk("add() creating device #%d\n", nr); + dev = kzalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) + return -ENOMEM; + + /* allocate id, if @id >= 0, we're requesting that specific id */ + if (nr >= 0) { + err = idr_alloc(&v4l2loopback_index_idr, dev, nr, nr + 1, + GFP_KERNEL); + if (err == -ENOSPC) + err = -EEXIST; + } else { + err = idr_alloc(&v4l2loopback_index_idr, dev, 0, 0, GFP_KERNEL); + } + if (err < 0) + goto out_free_dev; + + /* register new device */ + MARK(); + nr = err; + + if (conf && conf->card_label[0]) { + snprintf(dev->card_label, sizeof(dev->card_label), "%s", + conf->card_label); + } else { + snprintf(dev->card_label, sizeof(dev->card_label), + "Dummy video device (0x%04X)", nr); + } + snprintf(dev->v4l2_dev.name, sizeof(dev->v4l2_dev.name), + "v4l2loopback-%03d", nr); + + err = v4l2_device_register(NULL, &dev->v4l2_dev); + if (err) + goto out_free_idr; + + /* initialise the _video_ device */ + MARK(); + err = -ENOMEM; + dev->vdev = video_device_alloc(); + if (dev->vdev == NULL) + goto out_unregister; + + vdev_priv = kzalloc(sizeof(struct v4l2loopback_private), GFP_KERNEL); + if (vdev_priv == NULL) + goto out_unregister; + + video_set_drvdata(dev->vdev, vdev_priv); + if (video_get_drvdata(dev->vdev) == NULL) + goto out_unregister; + + snprintf(dev->vdev->name, sizeof(dev->vdev->name), "%s", + dev->card_label); + vdev_priv->device_nr = nr; + init_vdev(dev->vdev, nr); + dev->vdev->v4l2_dev = &dev->v4l2_dev; + + /* initialise v4l2-loopback specific parameters */ + MARK(); + dev->announce_all_caps = _announce_all_caps; + dev->min_width = _min_width; + dev->min_height = _min_height; + dev->max_width = _max_width; + dev->max_height = _max_height; + dev->max_openers = _max_openers; + + /* set (initial) pixel and stream format */ + _width = clamp_val(_width, _min_width, _max_width); + _height = clamp_val(_height, _min_height, _max_height); + _fmt = (struct v4l2_format){ + .type = V4L2_BUF_TYPE_VIDEO_CAPTURE, + .fmt.pix = { .width = _width, + .height = _height, + .pixelformat = formats[0].fourcc, + .colorspace = V4L2_COLORSPACE_DEFAULT, + .field = V4L2_FIELD_NONE } + }; + + err = v4l2l_fill_format(&_fmt, _min_width, _max_width, _min_height, + _max_height); + if (err) + /* highly unexpected failure to assign default format */ + goto out_unregister; + dev->pix_format = _fmt.fmt.pix; + init_capture_param(&dev->capture_param); + set_timeperframe(dev, &dev->capture_param.timeperframe); + + /* ctrls parameters */ + dev->keep_format = 0; + dev->sustain_framerate = 0; + dev->timeout_jiffies = 0; + dev->timeout_image_io = 0; + + /* initialise OUTPUT and CAPTURE buffer values */ + dev->image = NULL; + dev->image_size = 0; + dev->buffer_count = _max_buffers; + dev->buffer_size = 0; + dev->used_buffer_count = 0; + INIT_LIST_HEAD(&dev->outbufs_list); + do { + u32 index; + for (index = 0; index < dev->buffer_count; ++index) + INIT_LIST_HEAD(&dev->buffers[index].list_head); + + } while (0); + memset(dev->bufpos2index, 0, sizeof(dev->bufpos2index)); + dev->write_position = 0; + + /* initialise synchronisation data */ + atomic_set(&dev->open_count, 0); + mutex_init(&dev->image_mutex); + spin_lock_init(&dev->lock); + spin_lock_init(&dev->list_lock); + init_waitqueue_head(&dev->read_event); + dev->format_tokens = V4L2L_TOKEN_MASK; + dev->stream_tokens = V4L2L_TOKEN_MASK; + + /* initialise sustain frame rate and timeout parameters, and timers */ + dev->reread_count = 0; + dev->timeout_image = NULL; + dev->timeout_happened = 0; +#ifdef HAVE_TIMER_SETUP + timer_setup(&dev->sustain_timer, sustain_timer_clb, 0); + timer_setup(&dev->timeout_timer, timeout_timer_clb, 0); +#else + setup_timer(&dev->sustain_timer, sustain_timer_clb, nr); + setup_timer(&dev->timeout_timer, timeout_timer_clb, nr); +#endif + + /* initialise the control handler and add controls */ + MARK(); + hdl = &dev->ctrl_handler; + err = v4l2_ctrl_handler_init(hdl, 4); + if (err) + goto out_unregister; + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_keepformat, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_sustainframerate, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeout, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeoutimageio, NULL); + if (hdl->error) { + err = hdl->error; + goto out_free_handler; + } + dev->v4l2_dev.ctrl_handler = hdl; + + err = v4l2_ctrl_handler_setup(hdl); + if (err) + goto out_free_handler; + + /* register the device (creates /dev/video*) */ + MARK(); + if (video_register_device(dev->vdev, VFL_TYPE_VIDEO, nr) < 0) { + printk(KERN_ERR + "v4l2-loopback add() failed video_register_device()\n"); + err = -EFAULT; + goto out_free_device; + } + v4l2loopback_create_sysfs(dev->vdev); + /* NOTE: ambivalent if sysfs entries fail */ + + if (ret_nr) + *ret_nr = dev->vdev->num; + return 0; + +out_free_device: + video_device_release(dev->vdev); +out_free_handler: + v4l2_ctrl_handler_free(&dev->ctrl_handler); +out_unregister: + video_set_drvdata(dev->vdev, NULL); + if (vdev_priv != NULL) + kfree(vdev_priv); + v4l2_device_unregister(&dev->v4l2_dev); +out_free_idr: + idr_remove(&v4l2loopback_index_idr, nr); +out_free_dev: + kfree(dev); + return err; +} + +static void v4l2_loopback_remove(struct v4l2_loopback_device *dev) +{ + int device_nr = v4l2loopback_get_vdev_nr(dev->vdev); + mutex_lock(&dev->image_mutex); + free_buffers(dev); + free_timeout_buffer(dev); + mutex_unlock(&dev->image_mutex); + v4l2loopback_remove_sysfs(dev->vdev); + v4l2_ctrl_handler_free(&dev->ctrl_handler); + kfree(video_get_drvdata(dev->vdev)); + video_unregister_device(dev->vdev); + v4l2_device_unregister(&dev->v4l2_dev); + idr_remove(&v4l2loopback_index_idr, device_nr); + kfree(dev); +} + +static long v4l2loopback_control_ioctl(struct file *file, unsigned int cmd, + unsigned long parm) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_config conf; + struct v4l2_loopback_config *confptr = &conf; + int device_nr, capture_nr, output_nr; + int ret; + + ret = mutex_lock_killable(&v4l2loopback_ctl_mutex); + if (ret) + return ret; + + ret = -EINVAL; + switch (cmd) { + default: + ret = -ENOSYS; + break; + /* add a v4l2loopback device (pair), based on the user-provided specs */ + case V4L2LOOPBACK_CTL_ADD: + if (parm) { + if ((ret = copy_from_user(&conf, (void *)parm, + sizeof(conf))) < 0) + break; + } else + confptr = NULL; + ret = v4l2_loopback_add(confptr, &device_nr); + if (ret >= 0) + ret = device_nr; + break; + /* remove a v4l2loopback device (both capture and output) */ + case V4L2LOOPBACK_CTL_REMOVE: + ret = v4l2loopback_lookup((int)parm, &dev); + if (ret >= 0 && dev) { + ret = -EBUSY; + if (dev->open_count.counter > 0) + break; + v4l2_loopback_remove(dev); + ret = 0; + }; + break; + /* get information for a loopback device. + * this is mostly about limits (which cannot be queried directly with VIDIOC_G_FMT and friends + */ + case V4L2LOOPBACK_CTL_QUERY: + if (!parm) + break; + if ((ret = copy_from_user(&conf, (void *)parm, sizeof(conf))) < + 0) + break; + capture_nr = output_nr = conf.output_nr; +#ifdef SPLIT_DEVICES + capture_nr = conf.capture_nr; +#endif + device_nr = (output_nr < 0) ? capture_nr : output_nr; + MARK(); + /* get the device from either capture_nr or output_nr (whatever is valid) */ + if ((ret = v4l2loopback_lookup(device_nr, &dev)) < 0) + break; + MARK(); + /* if we got the device from output_nr and there is a valid capture_nr, + * make sure that both refer to the same device (or bail out) + */ + if ((device_nr != capture_nr) && (capture_nr >= 0) && + ((ret = v4l2loopback_lookup(capture_nr, 0)) < 0)) + break; + MARK(); + /* if otoh, we got the device from capture_nr and there is a valid output_nr, + * make sure that both refer to the same device (or bail out) + */ + if ((device_nr != output_nr) && (output_nr >= 0) && + ((ret = v4l2loopback_lookup(output_nr, 0)) < 0)) + break; + + /* v4l2_loopback_config identified a single device, so fetch the data */ + snprintf(conf.card_label, sizeof(conf.card_label), "%s", + dev->card_label); + + conf.output_nr = dev->vdev->num; +#ifdef SPLIT_DEVICES + conf.capture_nr = dev->vdev->num; +#endif + conf.min_width = dev->min_width; + conf.min_height = dev->min_height; + conf.max_width = dev->max_width; + conf.max_height = dev->max_height; + conf.announce_all_caps = dev->announce_all_caps; + conf.max_buffers = dev->buffer_count; + conf.max_openers = dev->max_openers; + conf.debug = debug; + MARK(); + if (copy_to_user((void *)parm, &conf, sizeof(conf))) { + ret = -EFAULT; + break; + } + ret = 0; + break; + } + + mutex_unlock(&v4l2loopback_ctl_mutex); + MARK(); + return ret; +} + +/* LINUX KERNEL */ + +static const struct file_operations v4l2loopback_ctl_fops = { + // clang-format off + .owner = THIS_MODULE, + .open = nonseekable_open, + .unlocked_ioctl = v4l2loopback_control_ioctl, + .compat_ioctl = v4l2loopback_control_ioctl, + .llseek = noop_llseek, + // clang-format on +}; + +static struct miscdevice v4l2loopback_misc = { + // clang-format off + .minor = MISC_DYNAMIC_MINOR, + .name = "v4l2loopback", + .fops = &v4l2loopback_ctl_fops, + // clang-format on +}; + +static const struct v4l2_file_operations v4l2_loopback_fops = { + // clang-format off + .owner = THIS_MODULE, + .open = v4l2_loopback_open, + .release = v4l2_loopback_close, + .read = v4l2_loopback_read, + .write = v4l2_loopback_write, + .poll = v4l2_loopback_poll, + .mmap = v4l2_loopback_mmap, + .unlocked_ioctl = video_ioctl2, + // clang-format on +}; + +static const struct v4l2_ioctl_ops v4l2_loopback_ioctl_ops = { + // clang-format off + .vidioc_querycap = &vidioc_querycap, + .vidioc_enum_framesizes = &vidioc_enum_framesizes, + .vidioc_enum_frameintervals = &vidioc_enum_frameintervals, + + .vidioc_enum_output = &vidioc_enum_output, + .vidioc_g_output = &vidioc_g_output, + .vidioc_s_output = &vidioc_s_output, + + .vidioc_enum_input = &vidioc_enum_input, + .vidioc_g_input = &vidioc_g_input, + .vidioc_s_input = &vidioc_s_input, + + .vidioc_enum_fmt_vid_cap = &vidioc_enum_fmt_cap, + .vidioc_g_fmt_vid_cap = &vidioc_g_fmt_cap, + .vidioc_s_fmt_vid_cap = &vidioc_s_fmt_cap, + .vidioc_try_fmt_vid_cap = &vidioc_try_fmt_cap, + + .vidioc_enum_fmt_vid_out = &vidioc_enum_fmt_out, + .vidioc_s_fmt_vid_out = &vidioc_s_fmt_out, + .vidioc_g_fmt_vid_out = &vidioc_g_fmt_out, + .vidioc_try_fmt_vid_out = &vidioc_try_fmt_out, + +#ifdef V4L2L_OVERLAY + .vidioc_s_fmt_vid_overlay = &vidioc_s_fmt_overlay, + .vidioc_g_fmt_vid_overlay = &vidioc_g_fmt_overlay, +#endif + +#ifdef V4L2LOOPBACK_WITH_STD + .vidioc_s_std = &vidioc_s_std, + .vidioc_g_std = &vidioc_g_std, + .vidioc_querystd = &vidioc_querystd, +#endif /* V4L2LOOPBACK_WITH_STD */ + + .vidioc_g_parm = &vidioc_g_parm, + .vidioc_s_parm = &vidioc_s_parm, + + .vidioc_reqbufs = &vidioc_reqbufs, + .vidioc_querybuf = &vidioc_querybuf, + .vidioc_qbuf = &vidioc_qbuf, + .vidioc_dqbuf = &vidioc_dqbuf, + + .vidioc_streamon = &vidioc_streamon, + .vidioc_streamoff = &vidioc_streamoff, + +#ifdef CONFIG_VIDEO_V4L1_COMPAT + .vidiocgmbuf = &vidiocgmbuf, +#endif + + .vidioc_subscribe_event = &vidioc_subscribe_event, + .vidioc_unsubscribe_event = &v4l2_event_unsubscribe, + // clang-format on +}; + +static int free_device_cb(int id, void *ptr, void *data) +{ + struct v4l2_loopback_device *dev = ptr; + v4l2_loopback_remove(dev); + return 0; +} +static void free_devices(void) +{ + idr_for_each(&v4l2loopback_index_idr, &free_device_cb, NULL); + idr_destroy(&v4l2loopback_index_idr); +} + +static int __init v4l2loopback_init_module(void) +{ + const u32 min_width = V4L2LOOPBACK_SIZE_MIN_WIDTH; + const u32 min_height = V4L2LOOPBACK_SIZE_MIN_HEIGHT; + int err; + int i; + MARK(); + + err = misc_register(&v4l2loopback_misc); + if (err < 0) + return err; + + if (devices < 0) { + devices = 1; + + /* try guessing the devices from the "video_nr" parameter */ + for (i = MAX_DEVICES - 1; i >= 0; i--) { + if (video_nr[i] >= 0) { + devices = i + 1; + break; + } + } + } + + if (devices > MAX_DEVICES) { + devices = MAX_DEVICES; + printk(KERN_INFO + "v4l2-loopback init() number of initial devices is " + "limited to: %d\n", + MAX_DEVICES); + } + + if (max_buffers > MAX_BUFFERS) { + max_buffers = MAX_BUFFERS; + printk(KERN_INFO + "v4l2-loopback init() number of buffers is limited " + "to: %d\n", + MAX_BUFFERS); + } + + if (max_openers < 0) { + printk(KERN_INFO + "v4l2-loopback init() allowing %d openers rather " + "than %d\n", + 2, max_openers); + max_openers = 2; + } + + if (max_width < min_width) { + max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; + printk(KERN_INFO "v4l2-loopback init() using max_width %d\n", + max_width); + } + if (max_height < min_height) { + max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; + printk(KERN_INFO "v4l2-loopback init() using max_height %d\n", + max_height); + } + + for (i = 0; i < devices; i++) { + struct v4l2_loopback_config cfg = { + // clang-format off + .output_nr = video_nr[i], +#ifdef SPLIT_DEVICES + .capture_nr = video_nr[i], +#endif + .min_width = min_width, + .min_height = min_height, + .max_width = max_width, + .max_height = max_height, + .announce_all_caps = (!exclusive_caps[i]), + .max_buffers = max_buffers, + .max_openers = max_openers, + .debug = debug, + // clang-format on + }; + cfg.card_label[0] = 0; + if (card_label[i]) + snprintf(cfg.card_label, sizeof(cfg.card_label), "%s", + card_label[i]); + err = v4l2_loopback_add(&cfg, 0); + if (err) { + free_devices(); + goto error; + } + } + + dprintk("module installed\n"); + + printk(KERN_INFO "v4l2-loopback driver version %d.%d.%d%s loaded\n", + // clang-format off + (V4L2LOOPBACK_VERSION_CODE >> 16) & 0xff, + (V4L2LOOPBACK_VERSION_CODE >> 8) & 0xff, + (V4L2LOOPBACK_VERSION_CODE ) & 0xff, +#ifdef SNAPSHOT_VERSION + " (" __stringify(SNAPSHOT_VERSION) ")" +#else + "" +#endif + ); + // clang-format on + + return 0; +error: + misc_deregister(&v4l2loopback_misc); + return err; +} + +static void v4l2loopback_cleanup_module(void) +{ + MARK(); + /* unregister the device -> it deletes /dev/video* */ + free_devices(); + /* and get rid of /dev/v4l2loopback */ + misc_deregister(&v4l2loopback_misc); + dprintk("module removed\n"); +} + +MODULE_ALIAS_MISCDEV(MISC_DYNAMIC_MINOR); + +module_init(v4l2loopback_init_module); +module_exit(v4l2loopback_cleanup_module); diff --git a/drivers/media/v4l2-core/v4l2loopback.h b/drivers/media/v4l2-core/v4l2loopback.h new file mode 100644 index 000000000000..ec0be6cbe97d --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback.h @@ -0,0 +1,98 @@ +/* SPDX-License-Identifier: GPL-2.0+ WITH Linux-syscall-note */ +/* + * v4l2loopback.h + * + * Written by IOhannes m zmölnig, 7/1/20. + * + * Copyright 2020 by IOhannes m zmölnig. Redistribution of this file is + * permitted under the GNU General Public License. + */ +#ifndef _V4L2LOOPBACK_H +#define _V4L2LOOPBACK_H + +#define V4L2LOOPBACK_VERSION_MAJOR 0 +#define V4L2LOOPBACK_VERSION_MINOR 14 +#define V4L2LOOPBACK_VERSION_BUGFIX 0 + +/* /dev/v4l2loopback interface */ + +struct v4l2_loopback_config { + /** + * the device-number (/dev/video) + * V4L2LOOPBACK_CTL_ADD: + * setting this to a value<0, will allocate an available one + * if nr>=0 and the device already exists, the ioctl will EEXIST + * if output_nr and capture_nr are the same, only a single device will be created + * NOTE: currently split-devices (where output_nr and capture_nr differ) + * are not implemented yet. + * until then, requesting different device-IDs will result in EINVAL. + * + * V4L2LOOPBACK_CTL_QUERY: + * either both output_nr and capture_nr must refer to the same loopback, + * or one (and only one) of them must be -1 + * + */ + int output_nr; + int unused; /*capture_nr;*/ + + /** + * a nice name for your device + * if (*card_label)==0, an automatic name is assigned + */ + char card_label[32]; + + /** + * allowed frame size + * if too low, default values are used + */ + unsigned int min_width; + unsigned int max_width; + unsigned int min_height; + unsigned int max_height; + + /** + * number of buffers to allocate for the queue + * if set to <=0, default values are used + */ + int max_buffers; + + /** + * how many consumers are allowed to open this device concurrently + * if set to <=0, default values are used + */ + int max_openers; + + /** + * set the debugging level for this device + */ + int debug; + + /** + * whether to announce OUTPUT/CAPTURE capabilities exclusively + * for this device or not + * (!exclusive_caps) + * NOTE: this is going to be removed once separate output/capture + * devices are implemented + */ + int announce_all_caps; +}; + +/* a pointer to a (struct v4l2_loopback_config) that has all values you wish to impose on the + * to-be-created device set. + * if the ptr is NULL, a new device is created with default values at the driver's discretion. + * + * returns the device_nr of the OUTPUT device (which can be used with V4L2LOOPBACK_CTL_QUERY, + * to get more information on the device) + */ +#define V4L2LOOPBACK_CTL_ADD 0x4C80 + +/* a pointer to a (struct v4l2_loopback_config) that has output_nr and/or capture_nr set + * (the two values must either refer to video-devices associated with the same loopback device + * or exactly one of them must be <0 + */ +#define V4L2LOOPBACK_CTL_QUERY 0x4C82 + +/* the device-number (either CAPTURE or OUTPUT) associated with the loopback-device */ +#define V4L2LOOPBACK_CTL_REMOVE 0x4C81 + +#endif /* _V4L2LOOPBACK_H */ diff --git a/drivers/media/v4l2-core/v4l2loopback_formats.h b/drivers/media/v4l2-core/v4l2loopback_formats.h new file mode 100644 index 000000000000..d855a3796554 --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback_formats.h @@ -0,0 +1,445 @@ +static const struct v4l2l_format formats[] = { +#ifndef V4L2_PIX_FMT_VP9 +#define V4L2_PIX_FMT_VP9 v4l2_fourcc('V', 'P', '9', '0') +#endif +#ifndef V4L2_PIX_FMT_HEVC +#define V4L2_PIX_FMT_HEVC v4l2_fourcc('H', 'E', 'V', 'C') +#endif + + /* here come the packed formats */ + { + .name = "32 bpp RGB, le", + .fourcc = V4L2_PIX_FMT_BGR32, + .depth = 32, + .flags = 0, + }, + { + .name = "32 bpp RGB, be", + .fourcc = V4L2_PIX_FMT_RGB32, + .depth = 32, + .flags = 0, + }, + { + .name = "24 bpp RGB, le", + .fourcc = V4L2_PIX_FMT_BGR24, + .depth = 24, + .flags = 0, + }, + { + .name = "24 bpp RGB, be", + .fourcc = V4L2_PIX_FMT_RGB24, + .depth = 24, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_ABGR32 + { + .name = "32 bpp RGBA, le", + .fourcc = V4L2_PIX_FMT_ABGR32, + .depth = 32, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_RGBA32 + { + .name = "32 bpp RGBA", + .fourcc = V4L2_PIX_FMT_RGBA32, + .depth = 32, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_RGB332 + { + .name = "8 bpp RGB-3-3-2", + .fourcc = V4L2_PIX_FMT_RGB332, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB332 */ +#ifdef V4L2_PIX_FMT_RGB444 + { + .name = "16 bpp RGB (xxxxrrrr ggggbbbb)", + .fourcc = V4L2_PIX_FMT_RGB444, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB444 */ +#ifdef V4L2_PIX_FMT_RGB555 + { + .name = "16 bpp RGB-5-5-5", + .fourcc = V4L2_PIX_FMT_RGB555, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB555 */ +#ifdef V4L2_PIX_FMT_RGB565 + { + .name = "16 bpp RGB-5-6-5", + .fourcc = V4L2_PIX_FMT_RGB565, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB565 */ +#ifdef V4L2_PIX_FMT_RGB555X + { + .name = "16 bpp RGB-5-5-5 BE", + .fourcc = V4L2_PIX_FMT_RGB555X, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB555X */ +#ifdef V4L2_PIX_FMT_RGB565X + { + .name = "16 bpp RGB-5-6-5 BE", + .fourcc = V4L2_PIX_FMT_RGB565X, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB565X */ +#ifdef V4L2_PIX_FMT_BGR666 + { + .name = "18 bpp BGR-6-6-6", + .fourcc = V4L2_PIX_FMT_BGR666, + .depth = 18, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_BGR666 */ + { + .name = "4:2:2, packed, YUYV", + .fourcc = V4L2_PIX_FMT_YUYV, + .depth = 16, + .flags = 0, + }, + { + .name = "4:2:2, packed, UYVY", + .fourcc = V4L2_PIX_FMT_UYVY, + .depth = 16, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_YVYU + { + .name = "4:2:2, packed YVYU", + .fourcc = V4L2_PIX_FMT_YVYU, + .depth = 16, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_VYUY + { + .name = "4:2:2, packed VYUY", + .fourcc = V4L2_PIX_FMT_VYUY, + .depth = 16, + .flags = 0, + }, +#endif + { + .name = "4:2:2, packed YYUV", + .fourcc = V4L2_PIX_FMT_YYUV, + .depth = 16, + .flags = 0, + }, + { + .name = "YUV-8-8-8-8", + .fourcc = V4L2_PIX_FMT_YUV32, + .depth = 32, + .flags = 0, + }, + { + .name = "8 bpp, Greyscale", + .fourcc = V4L2_PIX_FMT_GREY, + .depth = 8, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_Y4 + { + .name = "4 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y4, + .depth = 4, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y4 */ +#ifdef V4L2_PIX_FMT_Y6 + { + .name = "6 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y6, + .depth = 6, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y6 */ +#ifdef V4L2_PIX_FMT_Y10 + { + .name = "10 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y10, + .depth = 10, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y10 */ +#ifdef V4L2_PIX_FMT_Y12 + { + .name = "12 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y12, + .depth = 12, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y12 */ + { + .name = "16 bpp, Greyscale", + .fourcc = V4L2_PIX_FMT_Y16, + .depth = 16, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_YUV444 + { + .name = "16 bpp xxxxyyyy uuuuvvvv", + .fourcc = V4L2_PIX_FMT_YUV444, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV444 */ +#ifdef V4L2_PIX_FMT_YUV555 + { + .name = "16 bpp YUV-5-5-5", + .fourcc = V4L2_PIX_FMT_YUV555, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV555 */ +#ifdef V4L2_PIX_FMT_YUV565 + { + .name = "16 bpp YUV-5-6-5", + .fourcc = V4L2_PIX_FMT_YUV565, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV565 */ + +/* bayer formats */ +#ifdef V4L2_PIX_FMT_SRGGB8 + { + .name = "Bayer RGGB 8bit", + .fourcc = V4L2_PIX_FMT_SRGGB8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SRGGB8 */ +#ifdef V4L2_PIX_FMT_SGRBG8 + { + .name = "Bayer GRBG 8bit", + .fourcc = V4L2_PIX_FMT_SGRBG8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SGRBG8 */ +#ifdef V4L2_PIX_FMT_SGBRG8 + { + .name = "Bayer GBRG 8bit", + .fourcc = V4L2_PIX_FMT_SGBRG8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SGBRG8 */ +#ifdef V4L2_PIX_FMT_SBGGR8 + { + .name = "Bayer BA81 8bit", + .fourcc = V4L2_PIX_FMT_SBGGR8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SBGGR8 */ + + /* here come the planar formats */ + { + .name = "4:1:0, planar, Y-Cr-Cb", + .fourcc = V4L2_PIX_FMT_YVU410, + .depth = 9, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:2:0, planar, Y-Cr-Cb", + .fourcc = V4L2_PIX_FMT_YVU420, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:1:0, planar, Y-Cb-Cr", + .fourcc = V4L2_PIX_FMT_YUV410, + .depth = 9, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:2:0, planar, Y-Cb-Cr", + .fourcc = V4L2_PIX_FMT_YUV420, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#ifdef V4L2_PIX_FMT_YUV422P + { + .name = "16 bpp YVU422 planar", + .fourcc = V4L2_PIX_FMT_YUV422P, + .depth = 16, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_YUV422P */ +#ifdef V4L2_PIX_FMT_YUV411P + { + .name = "16 bpp YVU411 planar", + .fourcc = V4L2_PIX_FMT_YUV411P, + .depth = 16, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_YUV411P */ +#ifdef V4L2_PIX_FMT_Y41P + { + .name = "12 bpp YUV 4:1:1", + .fourcc = V4L2_PIX_FMT_Y41P, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_Y41P */ +#ifdef V4L2_PIX_FMT_NV12 + { + .name = "12 bpp Y/CbCr 4:2:0 ", + .fourcc = V4L2_PIX_FMT_NV12, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_NV12 */ + +/* here come the compressed formats */ + +#ifdef V4L2_PIX_FMT_MJPEG + { + .name = "Motion-JPEG", + .fourcc = V4L2_PIX_FMT_MJPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MJPEG */ +#ifdef V4L2_PIX_FMT_JPEG + { + .name = "JFIF JPEG", + .fourcc = V4L2_PIX_FMT_JPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_JPEG */ +#ifdef V4L2_PIX_FMT_DV + { + .name = "DV1394", + .fourcc = V4L2_PIX_FMT_DV, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_DV */ +#ifdef V4L2_PIX_FMT_MPEG + { + .name = "MPEG-1/2/4 Multiplexed", + .fourcc = V4L2_PIX_FMT_MPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG */ +#ifdef V4L2_PIX_FMT_H264 + { + .name = "H264 with start codes", + .fourcc = V4L2_PIX_FMT_H264, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264 */ +#ifdef V4L2_PIX_FMT_H264_NO_SC + { + .name = "H264 without start codes", + .fourcc = V4L2_PIX_FMT_H264_NO_SC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264_NO_SC */ +#ifdef V4L2_PIX_FMT_H264_MVC + { + .name = "H264 MVC", + .fourcc = V4L2_PIX_FMT_H264_MVC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264_MVC */ +#ifdef V4L2_PIX_FMT_H263 + { + .name = "H263", + .fourcc = V4L2_PIX_FMT_H263, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H263 */ +#ifdef V4L2_PIX_FMT_MPEG1 + { + .name = "MPEG-1 ES", + .fourcc = V4L2_PIX_FMT_MPEG1, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG1 */ +#ifdef V4L2_PIX_FMT_MPEG2 + { + .name = "MPEG-2 ES", + .fourcc = V4L2_PIX_FMT_MPEG2, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG2 */ +#ifdef V4L2_PIX_FMT_MPEG4 + { + .name = "MPEG-4 part 2 ES", + .fourcc = V4L2_PIX_FMT_MPEG4, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG4 */ +#ifdef V4L2_PIX_FMT_XVID + { + .name = "Xvid", + .fourcc = V4L2_PIX_FMT_XVID, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_XVID */ +#ifdef V4L2_PIX_FMT_VC1_ANNEX_G + { + .name = "SMPTE 421M Annex G compliant stream", + .fourcc = V4L2_PIX_FMT_VC1_ANNEX_G, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VC1_ANNEX_G */ +#ifdef V4L2_PIX_FMT_VC1_ANNEX_L + { + .name = "SMPTE 421M Annex L compliant stream", + .fourcc = V4L2_PIX_FMT_VC1_ANNEX_L, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VC1_ANNEX_L */ +#ifdef V4L2_PIX_FMT_VP8 + { + .name = "VP8", + .fourcc = V4L2_PIX_FMT_VP8, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VP8 */ +#ifdef V4L2_PIX_FMT_VP9 + { + .name = "VP9", + .fourcc = V4L2_PIX_FMT_VP9, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VP9 */ +#ifdef V4L2_PIX_FMT_HEVC + { + .name = "HEVC", + .fourcc = V4L2_PIX_FMT_HEVC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_HEVC */ +}; diff --git a/drivers/pci/controller/Makefile b/drivers/pci/controller/Makefile index 038ccbd9e3ba..de5e4f5145af 100644 --- a/drivers/pci/controller/Makefile +++ b/drivers/pci/controller/Makefile @@ -1,4 +1,10 @@ # SPDX-License-Identifier: GPL-2.0 +ifdef CONFIG_X86_64 +ifdef CONFIG_SATA_AHCI +obj-y += intel-nvme-remap.o +endif +endif + obj-$(CONFIG_PCIE_CADENCE) += cadence/ obj-$(CONFIG_PCI_FTPCI100) += pci-ftpci100.o obj-$(CONFIG_PCI_IXP4XX) += pci-ixp4xx.o diff --git a/drivers/pci/controller/intel-nvme-remap.c b/drivers/pci/controller/intel-nvme-remap.c new file mode 100644 index 000000000000..e105e6f5cc91 --- /dev/null +++ b/drivers/pci/controller/intel-nvme-remap.c @@ -0,0 +1,462 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Intel remapped NVMe device support. + * + * Copyright (c) 2019 Endless Mobile, Inc. + * Author: Daniel Drake + * + * Some products ship by default with the SATA controller in "RAID" or + * "Intel RST Premium With Intel Optane System Acceleration" mode. Under this + * mode, which we refer to as "remapped NVMe" mode, any installed NVMe + * devices disappear from the PCI bus, and instead their I/O memory becomes + * available within the AHCI device BARs. + * + * This scheme is understood to be a way of avoiding usage of the standard + * Windows NVMe driver under that OS, instead mandating usage of Intel's + * driver instead, which has better power management, and presumably offers + * some RAID/disk-caching solutions too. + * + * Here in this driver, we support the remapped NVMe mode by claiming the + * AHCI device and creating a fake PCIe root port. On the new bus, the + * original AHCI device is exposed with only minor tweaks. Then, fake PCI + * devices corresponding to the remapped NVMe devices are created. The usual + * ahci and nvme drivers are then expected to bind to these devices and + * operate as normal. + * + * The PCI configuration space for the NVMe devices is completely + * unavailable, so we fake a minimal one and hope for the best. + * + * Interrupts are shared between the AHCI and NVMe devices. For simplicity, + * we only support the legacy interrupt here, although MSI support + * could potentially be added later. + */ + +#define MODULE_NAME "intel-nvme-remap" + +#include +#include +#include +#include +#include + +#define AHCI_PCI_BAR_STANDARD 5 + +struct nvme_remap_dev { + struct pci_dev *dev; /* AHCI device */ + struct pci_bus *bus; /* our fake PCI bus */ + struct pci_sysdata sysdata; + int irq_base; /* our fake interrupts */ + + /* + * When we detect an all-ones write to a BAR register, this flag + * is set, so that we return the BAR size on the next read (a + * standard PCI behaviour). + * This includes the assumption that an all-ones BAR write is + * immediately followed by a read of the same register. + */ + bool bar_sizing; + + /* + * Resources copied from the AHCI device, to be regarded as + * resources on our fake bus. + */ + struct resource ahci_resources[PCI_NUM_RESOURCES]; + + /* Resources corresponding to the NVMe devices. */ + struct resource remapped_dev_mem[AHCI_MAX_REMAP]; + + /* Number of remapped NVMe devices found. */ + int num_remapped_devices; +}; + +static inline struct nvme_remap_dev *nrdev_from_bus(struct pci_bus *bus) +{ + return container_of(bus->sysdata, struct nvme_remap_dev, sysdata); +} + + +/******** PCI configuration space **********/ + +/* + * Helper macros for tweaking returned contents of PCI configuration space. + * + * value contains len bytes of data read from reg. + * If fixup_reg is included in that range, fix up the contents of that + * register to fixed_value. + */ +#define NR_FIX8(fixup_reg, fixed_value) do { \ + if (reg <= fixup_reg && fixup_reg < reg + len) \ + ((u8 *) value)[fixup_reg - reg] = (u8) (fixed_value); \ + } while (0) + +#define NR_FIX16(fixup_reg, fixed_value) do { \ + NR_FIX8(fixup_reg, fixed_value); \ + NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ + } while (0) + +#define NR_FIX24(fixup_reg, fixed_value) do { \ + NR_FIX8(fixup_reg, fixed_value); \ + NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ + NR_FIX8(fixup_reg + 2, fixed_value >> 16); \ + } while (0) + +#define NR_FIX32(fixup_reg, fixed_value) do { \ + NR_FIX16(fixup_reg, (u16) fixed_value); \ + NR_FIX16(fixup_reg + 2, fixed_value >> 16); \ + } while (0) + +/* + * Read PCI config space of the slot 0 (AHCI) device. + * We pass through the read request to the underlying device, but + * tweak the results in some cases. + */ +static int nvme_remap_pci_read_slot0(struct pci_bus *bus, int reg, + int len, u32 *value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct pci_bus *ahci_dev_bus = nrdev->dev->bus; + int ret; + + ret = ahci_dev_bus->ops->read(ahci_dev_bus, nrdev->dev->devfn, + reg, len, value); + if (ret) + return ret; + + /* + * Adjust the device class, to prevent this driver from attempting to + * additionally probe the device we're simulating here. + */ + NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_SATA_AHCI); + + /* + * Unset interrupt pin, otherwise ACPI tries to find routing + * info for our virtual IRQ, fails, and complains. + */ + NR_FIX8(PCI_INTERRUPT_PIN, 0); + + /* + * Truncate the AHCI BAR to not include the region that covers the + * hidden devices. This will cause the ahci driver to successfully + * probe th new device (instead of handing it over to this driver). + */ + if (nrdev->bar_sizing) { + NR_FIX32(PCI_BASE_ADDRESS_5, ~(SZ_16K - 1)); + nrdev->bar_sizing = false; + } + + return PCIBIOS_SUCCESSFUL; +} + +/* + * Read PCI config space of a remapped device. + * Since the original PCI config space is inaccessible, we provide a minimal, + * fake config space instead. + */ +static int nvme_remap_pci_read_remapped(struct pci_bus *bus, unsigned int port, + int reg, int len, u32 *value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct resource *remapped_mem; + + if (port > nrdev->num_remapped_devices) + return PCIBIOS_DEVICE_NOT_FOUND; + + *value = 0; + remapped_mem = &nrdev->remapped_dev_mem[port - 1]; + + /* Set a Vendor ID, otherwise Linux assumes no device is present */ + NR_FIX16(PCI_VENDOR_ID, PCI_VENDOR_ID_INTEL); + + /* Always appear on & bus mastering */ + NR_FIX16(PCI_COMMAND, PCI_COMMAND_MEMORY | PCI_COMMAND_MASTER); + + /* Set class so that nvme driver probes us */ + NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_EXPRESS); + + if (nrdev->bar_sizing) { + NR_FIX32(PCI_BASE_ADDRESS_0, + ~(resource_size(remapped_mem) - 1)); + nrdev->bar_sizing = false; + } else { + resource_size_t mem_start = remapped_mem->start; + + mem_start |= PCI_BASE_ADDRESS_MEM_TYPE_64; + NR_FIX32(PCI_BASE_ADDRESS_0, mem_start); + mem_start >>= 32; + NR_FIX32(PCI_BASE_ADDRESS_1, mem_start); + } + + return PCIBIOS_SUCCESSFUL; +} + +/* Read PCI configuration space. */ +static int nvme_remap_pci_read(struct pci_bus *bus, unsigned int devfn, + int reg, int len, u32 *value) +{ + if (PCI_SLOT(devfn) == 0) + return nvme_remap_pci_read_slot0(bus, reg, len, value); + else + return nvme_remap_pci_read_remapped(bus, PCI_SLOT(devfn), + reg, len, value); +} + +/* + * Write PCI config space of the slot 0 (AHCI) device. + * Apart from the special case of BAR sizing, we disable all writes. + * Otherwise, the ahci driver could make changes (e.g. unset PCI bus master) + * that would affect the operation of the NVMe devices. + */ +static int nvme_remap_pci_write_slot0(struct pci_bus *bus, int reg, + int len, u32 value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct pci_bus *ahci_dev_bus = nrdev->dev->bus; + + if (reg >= PCI_BASE_ADDRESS_0 && reg <= PCI_BASE_ADDRESS_5) { + /* + * Writing all-ones to a BAR means that the size of the + * memory region is being checked. Flag this so that we can + * reply with an appropriate size on the next read. + */ + if (value == ~0) + nrdev->bar_sizing = true; + + return ahci_dev_bus->ops->write(ahci_dev_bus, + nrdev->dev->devfn, + reg, len, value); + } + + return PCIBIOS_SET_FAILED; +} + +/* + * Write PCI config space of a remapped device. + * Since the original PCI config space is inaccessible, we reject all + * writes, except for the special case of BAR probing. + */ +static int nvme_remap_pci_write_remapped(struct pci_bus *bus, + unsigned int port, + int reg, int len, u32 value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + + if (port > nrdev->num_remapped_devices) + return PCIBIOS_DEVICE_NOT_FOUND; + + /* + * Writing all-ones to a BAR means that the size of the memory + * region is being checked. Flag this so that we can reply with + * an appropriate size on the next read. + */ + if (value == ~0 && reg >= PCI_BASE_ADDRESS_0 + && reg <= PCI_BASE_ADDRESS_5) { + nrdev->bar_sizing = true; + return PCIBIOS_SUCCESSFUL; + } + + return PCIBIOS_SET_FAILED; +} + +/* Write PCI configuration space. */ +static int nvme_remap_pci_write(struct pci_bus *bus, unsigned int devfn, + int reg, int len, u32 value) +{ + if (PCI_SLOT(devfn) == 0) + return nvme_remap_pci_write_slot0(bus, reg, len, value); + else + return nvme_remap_pci_write_remapped(bus, PCI_SLOT(devfn), + reg, len, value); +} + +static struct pci_ops nvme_remap_pci_ops = { + .read = nvme_remap_pci_read, + .write = nvme_remap_pci_write, +}; + + +/******** Initialization & exit **********/ + +/* + * Find a PCI domain ID to use for our fake bus. + * Start at 0x10000 to not clash with ACPI _SEG domains (16 bits). + */ +static int find_free_domain(void) +{ + int domain = 0xffff; + struct pci_bus *bus = NULL; + + while ((bus = pci_find_next_bus(bus)) != NULL) + domain = max_t(int, domain, pci_domain_nr(bus)); + + return domain + 1; +} + +static int find_remapped_devices(struct nvme_remap_dev *nrdev, + struct list_head *resources) +{ + void __iomem *mmio; + int i, count = 0; + u32 cap; + + mmio = pcim_iomap(nrdev->dev, AHCI_PCI_BAR_STANDARD, + pci_resource_len(nrdev->dev, + AHCI_PCI_BAR_STANDARD)); + if (!mmio) + return -ENODEV; + + /* Check if this device might have remapped nvme devices. */ + if (pci_resource_len(nrdev->dev, AHCI_PCI_BAR_STANDARD) < SZ_512K || + !(readl(mmio + AHCI_VSCAP) & 1)) + return -ENODEV; + + cap = readq(mmio + AHCI_REMAP_CAP); + for (i = AHCI_MAX_REMAP-1; i >= 0; i--) { + struct resource *remapped_mem; + + if ((cap & (1 << i)) == 0) + continue; + if (readl(mmio + ahci_remap_dcc(i)) + != PCI_CLASS_STORAGE_EXPRESS) + continue; + + /* We've found a remapped device */ + remapped_mem = &nrdev->remapped_dev_mem[count++]; + remapped_mem->start = + pci_resource_start(nrdev->dev, AHCI_PCI_BAR_STANDARD) + + ahci_remap_base(i); + remapped_mem->end = remapped_mem->start + + AHCI_REMAP_N_SIZE - 1; + remapped_mem->flags = IORESOURCE_MEM | IORESOURCE_PCI_FIXED; + pci_add_resource(resources, remapped_mem); + } + + pcim_iounmap(nrdev->dev, mmio); + + if (count == 0) + return -ENODEV; + + nrdev->num_remapped_devices = count; + dev_info(&nrdev->dev->dev, "Found %d remapped NVMe devices\n", + nrdev->num_remapped_devices); + return 0; +} + +static void nvme_remap_remove_root_bus(void *data) +{ + struct pci_bus *bus = data; + + pci_stop_root_bus(bus); + pci_remove_root_bus(bus); +} + +static int nvme_remap_probe(struct pci_dev *dev, + const struct pci_device_id *id) +{ + struct nvme_remap_dev *nrdev; + LIST_HEAD(resources); + int i; + int ret; + struct pci_dev *child; + + nrdev = devm_kzalloc(&dev->dev, sizeof(*nrdev), GFP_KERNEL); + nrdev->sysdata.domain = find_free_domain(); + nrdev->sysdata.nvme_remap_dev = dev; + nrdev->dev = dev; + pci_set_drvdata(dev, nrdev); + + ret = pcim_enable_device(dev); + if (ret < 0) + return ret; + + pci_set_master(dev); + + ret = find_remapped_devices(nrdev, &resources); + if (ret) + return ret; + + /* Add resources from the original AHCI device */ + for (i = 0; i < PCI_NUM_RESOURCES; i++) { + struct resource *res = &dev->resource[i]; + + if (res->start) { + struct resource *nr_res = &nrdev->ahci_resources[i]; + + nr_res->start = res->start; + nr_res->end = res->end; + nr_res->flags = res->flags; + pci_add_resource(&resources, nr_res); + } + } + + /* Create virtual interrupts */ + nrdev->irq_base = devm_irq_alloc_descs(&dev->dev, -1, 0, + nrdev->num_remapped_devices + 1, + 0); + if (nrdev->irq_base < 0) + return nrdev->irq_base; + + /* Create and populate PCI bus */ + nrdev->bus = pci_create_root_bus(&dev->dev, 0, &nvme_remap_pci_ops, + &nrdev->sysdata, &resources); + if (!nrdev->bus) + return -ENODEV; + + if (devm_add_action_or_reset(&dev->dev, nvme_remap_remove_root_bus, + nrdev->bus)) + return -ENOMEM; + + /* We don't support sharing MSI interrupts between these devices */ + nrdev->bus->bus_flags |= PCI_BUS_FLAGS_NO_MSI; + + pci_scan_child_bus(nrdev->bus); + + list_for_each_entry(child, &nrdev->bus->devices, bus_list) { + /* + * Prevent PCI core from trying to move memory BARs around. + * The hidden NVMe devices are at fixed locations. + */ + for (i = 0; i < PCI_NUM_RESOURCES; i++) { + struct resource *res = &child->resource[i]; + + if (res->flags & IORESOURCE_MEM) + res->flags |= IORESOURCE_PCI_FIXED; + } + + /* Share the legacy IRQ between all devices */ + child->irq = dev->irq; + } + + pci_assign_unassigned_bus_resources(nrdev->bus); + pci_bus_add_devices(nrdev->bus); + + return 0; +} + +static const struct pci_device_id nvme_remap_ids[] = { + /* + * Match all Intel RAID controllers. + * + * There's overlap here with the set of devices detected by the ahci + * driver, but ahci will only successfully probe when there + * *aren't* any remapped NVMe devices, and this driver will only + * successfully probe when there *are* remapped NVMe devices that + * need handling. + */ + { + PCI_VDEVICE(INTEL, PCI_ANY_ID), + .class = PCI_CLASS_STORAGE_RAID << 8, + .class_mask = 0xffffff00, + }, + {0,} +}; +MODULE_DEVICE_TABLE(pci, nvme_remap_ids); + +static struct pci_driver nvme_remap_drv = { + .name = MODULE_NAME, + .id_table = nvme_remap_ids, + .probe = nvme_remap_probe, +}; +module_pci_driver(nvme_remap_drv); + +MODULE_AUTHOR("Daniel Drake "); +MODULE_LICENSE("GPL v2"); diff --git a/drivers/pci/quirks.c b/drivers/pci/quirks.c index 82b21e34c545..fa167d29821f 100644 --- a/drivers/pci/quirks.c +++ b/drivers/pci/quirks.c @@ -3747,6 +3747,106 @@ static void quirk_no_bus_reset(struct pci_dev *dev) dev->dev_flags |= PCI_DEV_FLAGS_NO_BUS_RESET; } +static bool acs_on_downstream; +static bool acs_on_multifunction; + +#define NUM_ACS_IDS 16 +struct acs_on_id { + unsigned short vendor; + unsigned short device; +}; +static struct acs_on_id acs_on_ids[NUM_ACS_IDS]; +static u8 max_acs_id; + +static __init int pcie_acs_override_setup(char *p) +{ + if (!p) + return -EINVAL; + + while (*p) { + if (!strncmp(p, "downstream", 10)) + acs_on_downstream = true; + if (!strncmp(p, "multifunction", 13)) + acs_on_multifunction = true; + if (!strncmp(p, "id:", 3)) { + char opt[5]; + int ret; + long val; + + if (max_acs_id >= NUM_ACS_IDS - 1) { + pr_warn("Out of PCIe ACS override slots (%d)\n", + NUM_ACS_IDS); + goto next; + } + + p += 3; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].vendor = val; + + p += strcspn(p, ":"); + if (*p != ':') { + pr_warn("PCIe ACS invalid ID\n"); + goto next; + } + + p++; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].device = val; + max_acs_id++; + } +next: + p += strcspn(p, ","); + if (*p == ',') + p++; + } + + if (acs_on_downstream || acs_on_multifunction || max_acs_id) + pr_warn("Warning: PCIe ACS overrides enabled; This may allow non-IOMMU protected peer-to-peer DMA\n"); + + return 0; +} +early_param("pcie_acs_override", pcie_acs_override_setup); + +static int pcie_acs_overrides(struct pci_dev *dev, u16 acs_flags) +{ + int i; + + /* Never override ACS for legacy devices or devices with ACS caps */ + if (!pci_is_pcie(dev) || + pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ACS)) + return -ENOTTY; + + for (i = 0; i < max_acs_id; i++) + if (acs_on_ids[i].vendor == dev->vendor && + acs_on_ids[i].device == dev->device) + return 1; + + switch (pci_pcie_type(dev)) { + case PCI_EXP_TYPE_DOWNSTREAM: + case PCI_EXP_TYPE_ROOT_PORT: + if (acs_on_downstream) + return 1; + break; + case PCI_EXP_TYPE_ENDPOINT: + case PCI_EXP_TYPE_UPSTREAM: + case PCI_EXP_TYPE_LEG_END: + case PCI_EXP_TYPE_RC_END: + if (acs_on_multifunction && dev->multifunction) + return 1; + } + + return -ENOTTY; +} /* * Some NVIDIA GPU devices do not work with bus reset, SBR needs to be * prevented for those affected devices. @@ -5171,6 +5271,7 @@ static const struct pci_dev_acs_enabled { { PCI_VENDOR_ID_ZHAOXIN, PCI_ANY_ID, pci_quirk_zhaoxin_pcie_ports_acs }, /* Wangxun nics */ { PCI_VENDOR_ID_WANGXUN, PCI_ANY_ID, pci_quirk_wangxun_nic_acs }, + { PCI_ANY_ID, PCI_ANY_ID, pcie_acs_overrides }, { 0 } }; diff --git a/drivers/scsi/Kconfig b/drivers/scsi/Kconfig index 37c24ffea65c..bd52d1e081b7 100644 --- a/drivers/scsi/Kconfig +++ b/drivers/scsi/Kconfig @@ -1522,4 +1522,6 @@ endif # SCSI_LOWLEVEL source "drivers/scsi/device_handler/Kconfig" +source "drivers/scsi/vhba/Kconfig" + endmenu diff --git a/drivers/scsi/Makefile b/drivers/scsi/Makefile index 1313ddf2fd1a..5942e8f79159 100644 --- a/drivers/scsi/Makefile +++ b/drivers/scsi/Makefile @@ -153,6 +153,7 @@ obj-$(CONFIG_CHR_DEV_SCH) += ch.o obj-$(CONFIG_SCSI_ENCLOSURE) += ses.o obj-$(CONFIG_SCSI_HISI_SAS) += hisi_sas/ +obj-$(CONFIG_VHBA) += vhba/ # This goes last, so that "real" scsi devices probe earlier obj-$(CONFIG_SCSI_DEBUG) += scsi_debug.o diff --git a/drivers/scsi/vhba/Kconfig b/drivers/scsi/vhba/Kconfig new file mode 100644 index 000000000000..e70a381fe3df --- /dev/null +++ b/drivers/scsi/vhba/Kconfig @@ -0,0 +1,9 @@ +config VHBA + tristate "Virtual (SCSI) Host Bus Adapter" + depends on SCSI + help + This is the in-kernel part of CDEmu, a CD/DVD-ROM device + emulator. + + This driver can also be built as a module. If so, the module + will be called vhba. diff --git a/drivers/scsi/vhba/Makefile b/drivers/scsi/vhba/Makefile new file mode 100644 index 000000000000..2d7524b66199 --- /dev/null +++ b/drivers/scsi/vhba/Makefile @@ -0,0 +1,4 @@ +VHBA_VERSION := 20240917 + +obj-$(CONFIG_VHBA) += vhba.o +ccflags-y := -DVHBA_VERSION=\"$(VHBA_VERSION)\" -Werror diff --git a/drivers/scsi/vhba/vhba.c b/drivers/scsi/vhba/vhba.c new file mode 100644 index 000000000000..878a3be0ba2b --- /dev/null +++ b/drivers/scsi/vhba/vhba.c @@ -0,0 +1,1132 @@ +/* + * vhba.c + * + * Copyright (C) 2007-2012 Chia-I Wu + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#define pr_fmt(fmt) "vhba: " fmt + +#include + +#include +#include +#include +#include +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 11, 0) +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#ifdef CONFIG_COMPAT +#include +#endif +#include +#include +#include +#include +#include +#include + + +MODULE_AUTHOR("Chia-I Wu"); +MODULE_VERSION(VHBA_VERSION); +MODULE_DESCRIPTION("Virtual SCSI HBA"); +MODULE_LICENSE("GPL"); + + +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 15, 0) +#define sdev_dbg(sdev, fmt, a...) \ + dev_dbg(&(sdev)->sdev_gendev, fmt, ##a) +#define scmd_dbg(scmd, fmt, a...) \ + dev_dbg(&(scmd)->device->sdev_gendev, fmt, ##a) +#endif + +#define VHBA_MAX_SECTORS_PER_IO 256 +#define VHBA_MAX_BUS 16 +#define VHBA_MAX_ID 16 +#define VHBA_MAX_DEVICES (VHBA_MAX_BUS * (VHBA_MAX_ID-1)) +#define VHBA_KBUF_SIZE PAGE_SIZE + +#define DATA_TO_DEVICE(dir) ((dir) == DMA_TO_DEVICE || (dir) == DMA_BIDIRECTIONAL) +#define DATA_FROM_DEVICE(dir) ((dir) == DMA_FROM_DEVICE || (dir) == DMA_BIDIRECTIONAL) + + +static int vhba_can_queue = 32; +module_param_named(can_queue, vhba_can_queue, int, 0); + + +enum vhba_req_state { + VHBA_REQ_FREE, + VHBA_REQ_PENDING, + VHBA_REQ_READING, + VHBA_REQ_SENT, + VHBA_REQ_WRITING, +}; + +struct vhba_command { + struct scsi_cmnd *cmd; + /* metatags are per-host. not to be confused with + queue tags that are usually per-lun */ + unsigned long metatag; + int status; + struct list_head entry; +}; + +struct vhba_device { + unsigned int num; + spinlock_t cmd_lock; + struct list_head cmd_list; + wait_queue_head_t cmd_wq; + atomic_t refcnt; + + unsigned char *kbuf; + size_t kbuf_size; +}; + +struct vhba_host { + struct Scsi_Host *shost; + spinlock_t cmd_lock; + int cmd_next; + struct vhba_command *commands; + spinlock_t dev_lock; + struct vhba_device *devices[VHBA_MAX_DEVICES]; + int num_devices; + DECLARE_BITMAP(chgmap, VHBA_MAX_DEVICES); + int chgtype[VHBA_MAX_DEVICES]; + struct work_struct scan_devices; +}; + +#define MAX_COMMAND_SIZE 16 + +struct vhba_request { + __u32 metatag; + __u32 lun; + __u8 cdb[MAX_COMMAND_SIZE]; + __u8 cdb_len; + __u32 data_len; +}; + +struct vhba_response { + __u32 metatag; + __u32 status; + __u32 data_len; +}; + + + +static struct vhba_command *vhba_alloc_command (void); +static void vhba_free_command (struct vhba_command *vcmd); + +static struct platform_device vhba_platform_device; + + + +/* These functions define a symmetric 1:1 mapping between device numbers and + the bus and id. We have reserved the last id per bus for the host itself. */ +static void devnum_to_bus_and_id(unsigned int devnum, unsigned int *bus, unsigned int *id) +{ + *bus = devnum / (VHBA_MAX_ID-1); + *id = devnum % (VHBA_MAX_ID-1); +} + +static unsigned int bus_and_id_to_devnum(unsigned int bus, unsigned int id) +{ + return (bus * (VHBA_MAX_ID-1)) + id; +} + +static struct vhba_device *vhba_device_alloc (void) +{ + struct vhba_device *vdev; + + vdev = kzalloc(sizeof(struct vhba_device), GFP_KERNEL); + if (!vdev) { + return NULL; + } + + spin_lock_init(&vdev->cmd_lock); + INIT_LIST_HEAD(&vdev->cmd_list); + init_waitqueue_head(&vdev->cmd_wq); + atomic_set(&vdev->refcnt, 1); + + vdev->kbuf = NULL; + vdev->kbuf_size = 0; + + return vdev; +} + +static void vhba_device_put (struct vhba_device *vdev) +{ + if (atomic_dec_and_test(&vdev->refcnt)) { + kfree(vdev); + } +} + +static struct vhba_device *vhba_device_get (struct vhba_device *vdev) +{ + atomic_inc(&vdev->refcnt); + + return vdev; +} + +static int vhba_device_queue (struct vhba_device *vdev, struct scsi_cmnd *cmd) +{ + struct vhba_host *vhost; + struct vhba_command *vcmd; + unsigned long flags; + + vhost = platform_get_drvdata(&vhba_platform_device); + + vcmd = vhba_alloc_command(); + if (!vcmd) { + return SCSI_MLQUEUE_HOST_BUSY; + } + + vcmd->cmd = cmd; + + spin_lock_irqsave(&vdev->cmd_lock, flags); +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 0) + vcmd->metatag = scsi_cmd_to_rq(vcmd->cmd)->tag; +#else + vcmd->metatag = vcmd->cmd->request->tag; +#endif + list_add_tail(&vcmd->entry, &vdev->cmd_list); + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + wake_up_interruptible(&vdev->cmd_wq); + + return 0; +} + +static int vhba_device_dequeue (struct vhba_device *vdev, struct scsi_cmnd *cmd) +{ + struct vhba_command *vcmd; + int retval; + unsigned long flags; + + spin_lock_irqsave(&vdev->cmd_lock, flags); + list_for_each_entry(vcmd, &vdev->cmd_list, entry) { + if (vcmd->cmd == cmd) { + list_del_init(&vcmd->entry); + break; + } + } + + /* command not found */ + if (&vcmd->entry == &vdev->cmd_list) { + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + return SUCCESS; + } + + while (vcmd->status == VHBA_REQ_READING || vcmd->status == VHBA_REQ_WRITING) { + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + scmd_dbg(cmd, "wait for I/O before aborting\n"); + schedule_timeout(1); + spin_lock_irqsave(&vdev->cmd_lock, flags); + } + + retval = (vcmd->status == VHBA_REQ_SENT) ? FAILED : SUCCESS; + + vhba_free_command(vcmd); + + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + return retval; +} + +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 19, 0) +static int vhba_slave_alloc(struct scsi_device *sdev) +{ + struct Scsi_Host *shost = sdev->host; + + sdev_dbg(sdev, "enabling tagging (queue depth: %i).\n", sdev->queue_depth); +#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0) + if (!shost_use_blk_mq(shost) && shost->bqt) { +#else + if (shost->bqt) { +#endif + blk_queue_init_tags(sdev->request_queue, sdev->queue_depth, shost->bqt); + } + scsi_adjust_queue_depth(sdev, 0, sdev->queue_depth); + + return 0; +} +#endif + +static void vhba_scan_devices_add (struct vhba_host *vhost, int bus, int id) +{ + struct scsi_device *sdev; + + sdev = scsi_device_lookup(vhost->shost, bus, id, 0); + if (!sdev) { + scsi_add_device(vhost->shost, bus, id, 0); + } else { + dev_warn(&vhost->shost->shost_gendev, "tried to add an already-existing device %d:%d:0!\n", bus, id); + scsi_device_put(sdev); + } +} + +static void vhba_scan_devices_remove (struct vhba_host *vhost, int bus, int id) +{ + struct scsi_device *sdev; + + sdev = scsi_device_lookup(vhost->shost, bus, id, 0); + if (sdev) { + scsi_remove_device(sdev); + scsi_device_put(sdev); + } else { + dev_warn(&vhost->shost->shost_gendev, "tried to remove non-existing device %d:%d:0!\n", bus, id); + } +} + +static void vhba_scan_devices (struct work_struct *work) +{ + struct vhba_host *vhost = container_of(work, struct vhba_host, scan_devices); + unsigned long flags; + int change, exists; + unsigned int devnum; + unsigned int bus, id; + + for (;;) { + spin_lock_irqsave(&vhost->dev_lock, flags); + + devnum = find_first_bit(vhost->chgmap, VHBA_MAX_DEVICES); + if (devnum >= VHBA_MAX_DEVICES) { + spin_unlock_irqrestore(&vhost->dev_lock, flags); + break; + } + change = vhost->chgtype[devnum]; + exists = vhost->devices[devnum] != NULL; + + vhost->chgtype[devnum] = 0; + clear_bit(devnum, vhost->chgmap); + + spin_unlock_irqrestore(&vhost->dev_lock, flags); + + devnum_to_bus_and_id(devnum, &bus, &id); + + if (change < 0) { + dev_dbg(&vhost->shost->shost_gendev, "trying to remove target %d:%d:0\n", bus, id); + vhba_scan_devices_remove(vhost, bus, id); + } else if (change > 0) { + dev_dbg(&vhost->shost->shost_gendev, "trying to add target %d:%d:0\n", bus, id); + vhba_scan_devices_add(vhost, bus, id); + } else { + /* quick sequence of add/remove or remove/add; we determine + which one it was by checking if device structure exists */ + if (exists) { + /* remove followed by add: remove and (re)add */ + dev_dbg(&vhost->shost->shost_gendev, "trying to (re)add target %d:%d:0\n", bus, id); + vhba_scan_devices_remove(vhost, bus, id); + vhba_scan_devices_add(vhost, bus, id); + } else { + /* add followed by remove: no-op */ + dev_dbg(&vhost->shost->shost_gendev, "no-op for target %d:%d:0\n", bus, id); + } + } + } +} + +static int vhba_add_device (struct vhba_device *vdev) +{ + struct vhba_host *vhost; + unsigned int devnum; + unsigned long flags; + + vhost = platform_get_drvdata(&vhba_platform_device); + + vhba_device_get(vdev); + + spin_lock_irqsave(&vhost->dev_lock, flags); + if (vhost->num_devices >= VHBA_MAX_DEVICES) { + spin_unlock_irqrestore(&vhost->dev_lock, flags); + vhba_device_put(vdev); + return -EBUSY; + } + + for (devnum = 0; devnum < VHBA_MAX_DEVICES; devnum++) { + if (vhost->devices[devnum] == NULL) { + vdev->num = devnum; + vhost->devices[devnum] = vdev; + vhost->num_devices++; + set_bit(devnum, vhost->chgmap); + vhost->chgtype[devnum]++; + break; + } + } + spin_unlock_irqrestore(&vhost->dev_lock, flags); + + schedule_work(&vhost->scan_devices); + + return 0; +} + +static int vhba_remove_device (struct vhba_device *vdev) +{ + struct vhba_host *vhost; + unsigned long flags; + + vhost = platform_get_drvdata(&vhba_platform_device); + + spin_lock_irqsave(&vhost->dev_lock, flags); + set_bit(vdev->num, vhost->chgmap); + vhost->chgtype[vdev->num]--; + vhost->devices[vdev->num] = NULL; + vhost->num_devices--; + spin_unlock_irqrestore(&vhost->dev_lock, flags); + + vhba_device_put(vdev); + + schedule_work(&vhost->scan_devices); + + return 0; +} + +static struct vhba_device *vhba_lookup_device (int devnum) +{ + struct vhba_host *vhost; + struct vhba_device *vdev = NULL; + unsigned long flags; + + vhost = platform_get_drvdata(&vhba_platform_device); + + if (likely(devnum < VHBA_MAX_DEVICES)) { + spin_lock_irqsave(&vhost->dev_lock, flags); + vdev = vhost->devices[devnum]; + if (vdev) { + vdev = vhba_device_get(vdev); + } + + spin_unlock_irqrestore(&vhost->dev_lock, flags); + } + + return vdev; +} + +static struct vhba_command *vhba_alloc_command (void) +{ + struct vhba_host *vhost; + struct vhba_command *vcmd; + unsigned long flags; + int i; + + vhost = platform_get_drvdata(&vhba_platform_device); + + spin_lock_irqsave(&vhost->cmd_lock, flags); + + vcmd = vhost->commands + vhost->cmd_next++; + if (vcmd->status != VHBA_REQ_FREE) { + for (i = 0; i < vhba_can_queue; i++) { + vcmd = vhost->commands + i; + + if (vcmd->status == VHBA_REQ_FREE) { + vhost->cmd_next = i + 1; + break; + } + } + + if (i == vhba_can_queue) { + vcmd = NULL; + } + } + + if (vcmd) { + vcmd->status = VHBA_REQ_PENDING; + } + + vhost->cmd_next %= vhba_can_queue; + + spin_unlock_irqrestore(&vhost->cmd_lock, flags); + + return vcmd; +} + +static void vhba_free_command (struct vhba_command *vcmd) +{ + struct vhba_host *vhost; + unsigned long flags; + + vhost = platform_get_drvdata(&vhba_platform_device); + + spin_lock_irqsave(&vhost->cmd_lock, flags); + vcmd->status = VHBA_REQ_FREE; + spin_unlock_irqrestore(&vhost->cmd_lock, flags); +} + +static int vhba_queuecommand (struct Scsi_Host *shost, struct scsi_cmnd *cmd) +{ + struct vhba_device *vdev; + int retval; + unsigned int devnum; + +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 0) + scmd_dbg(cmd, "queue %p tag %i\n", cmd, scsi_cmd_to_rq(cmd)->tag); +#else + scmd_dbg(cmd, "queue %p tag %i\n", cmd, cmd->request->tag); +#endif + + devnum = bus_and_id_to_devnum(cmd->device->channel, cmd->device->id); + vdev = vhba_lookup_device(devnum); + if (!vdev) { + scmd_dbg(cmd, "no such device\n"); + + cmd->result = DID_NO_CONNECT << 16; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 16, 0) + scsi_done(cmd); +#else + cmd->scsi_done(cmd); +#endif + + return 0; + } + + retval = vhba_device_queue(vdev, cmd); + + vhba_device_put(vdev); + + return retval; +} + +static int vhba_abort (struct scsi_cmnd *cmd) +{ + struct vhba_device *vdev; + int retval = SUCCESS; + unsigned int devnum; + + scmd_dbg(cmd, "abort %p\n", cmd); + + devnum = bus_and_id_to_devnum(cmd->device->channel, cmd->device->id); + vdev = vhba_lookup_device(devnum); + if (vdev) { + retval = vhba_device_dequeue(vdev, cmd); + vhba_device_put(vdev); + } else { + cmd->result = DID_NO_CONNECT << 16; + } + + return retval; +} + +static struct scsi_host_template vhba_template = { + .module = THIS_MODULE, + .name = "vhba", + .proc_name = "vhba", + .queuecommand = vhba_queuecommand, + .eh_abort_handler = vhba_abort, + .this_id = -1, + .max_sectors = VHBA_MAX_SECTORS_PER_IO, + .sg_tablesize = 256, +#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 19, 0) + .slave_alloc = vhba_slave_alloc, +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 0, 0) && LINUX_VERSION_CODE < KERNEL_VERSION(6, 14, 0) + .tag_alloc_policy = BLK_TAG_ALLOC_RR, +#else + .tag_alloc_policy_rr = true, +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 19, 0) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 4, 0) + .use_blk_tags = 1, +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 0, 0) + .max_segment_size = VHBA_KBUF_SIZE, +#endif +}; + +static ssize_t do_request (struct vhba_device *vdev, unsigned long metatag, struct scsi_cmnd *cmd, char __user *buf, size_t buf_len) +{ + struct vhba_request vreq; + ssize_t ret; + + scmd_dbg(cmd, "request %lu (%p), cdb 0x%x, bufflen %d, sg count %d\n", + metatag, cmd, cmd->cmnd[0], scsi_bufflen(cmd), scsi_sg_count(cmd)); + + ret = sizeof(vreq); + if (DATA_TO_DEVICE(cmd->sc_data_direction)) { + ret += scsi_bufflen(cmd); + } + + if (ret > buf_len) { + scmd_dbg(cmd, "buffer too small (%zd < %zd) for a request\n", buf_len, ret); + return -EIO; + } + + vreq.metatag = metatag; + vreq.lun = cmd->device->lun; + memcpy(vreq.cdb, cmd->cmnd, MAX_COMMAND_SIZE); + vreq.cdb_len = cmd->cmd_len; + vreq.data_len = scsi_bufflen(cmd); + + if (copy_to_user(buf, &vreq, sizeof(vreq))) { + return -EFAULT; + } + + if (DATA_TO_DEVICE(cmd->sc_data_direction) && vreq.data_len) { + buf += sizeof(vreq); + + if (scsi_sg_count(cmd)) { + unsigned char *kaddr, *uaddr; + struct scatterlist *sglist = scsi_sglist(cmd); + struct scatterlist *sg; + int i; + + uaddr = (unsigned char *) buf; + + for_each_sg(sglist, sg, scsi_sg_count(cmd), i) { + size_t len = sg->length; + + if (len > vdev->kbuf_size) { + scmd_dbg(cmd, "segment size (%zu) exceeds kbuf size (%zu)!", len, vdev->kbuf_size); + len = vdev->kbuf_size; + } + + kaddr = kmap_atomic(sg_page(sg)); + memcpy(vdev->kbuf, kaddr + sg->offset, len); + kunmap_atomic(kaddr); + + if (copy_to_user(uaddr, vdev->kbuf, len)) { + return -EFAULT; + } + uaddr += len; + } + } else { + if (copy_to_user(buf, scsi_sglist(cmd), vreq.data_len)) { + return -EFAULT; + } + } + } + + return ret; +} + +static ssize_t do_response (struct vhba_device *vdev, unsigned long metatag, struct scsi_cmnd *cmd, const char __user *buf, size_t buf_len, struct vhba_response *res) +{ + ssize_t ret = 0; + + scmd_dbg(cmd, "response %lu (%p), status %x, data len %d, sg count %d\n", + metatag, cmd, res->status, res->data_len, scsi_sg_count(cmd)); + + if (res->status) { + if (res->data_len > SCSI_SENSE_BUFFERSIZE) { + scmd_dbg(cmd, "truncate sense (%d < %d)", SCSI_SENSE_BUFFERSIZE, res->data_len); + res->data_len = SCSI_SENSE_BUFFERSIZE; + } + + if (copy_from_user(cmd->sense_buffer, buf, res->data_len)) { + return -EFAULT; + } + + cmd->result = res->status; + + ret += res->data_len; + } else if (DATA_FROM_DEVICE(cmd->sc_data_direction) && scsi_bufflen(cmd)) { + size_t to_read; + + if (res->data_len > scsi_bufflen(cmd)) { + scmd_dbg(cmd, "truncate data (%d < %d)\n", scsi_bufflen(cmd), res->data_len); + res->data_len = scsi_bufflen(cmd); + } + + to_read = res->data_len; + + if (scsi_sg_count(cmd)) { + unsigned char *kaddr, *uaddr; + struct scatterlist *sglist = scsi_sglist(cmd); + struct scatterlist *sg; + int i; + + uaddr = (unsigned char *)buf; + + for_each_sg(sglist, sg, scsi_sg_count(cmd), i) { + size_t len = (sg->length < to_read) ? sg->length : to_read; + + if (len > vdev->kbuf_size) { + scmd_dbg(cmd, "segment size (%zu) exceeds kbuf size (%zu)!", len, vdev->kbuf_size); + len = vdev->kbuf_size; + } + + if (copy_from_user(vdev->kbuf, uaddr, len)) { + return -EFAULT; + } + uaddr += len; + + kaddr = kmap_atomic(sg_page(sg)); + memcpy(kaddr + sg->offset, vdev->kbuf, len); + kunmap_atomic(kaddr); + + to_read -= len; + if (to_read == 0) { + break; + } + } + } else { + if (copy_from_user(scsi_sglist(cmd), buf, res->data_len)) { + return -EFAULT; + } + + to_read -= res->data_len; + } + + scsi_set_resid(cmd, to_read); + + ret += res->data_len - to_read; + } + + return ret; +} + +static struct vhba_command *next_command (struct vhba_device *vdev) +{ + struct vhba_command *vcmd; + + list_for_each_entry(vcmd, &vdev->cmd_list, entry) { + if (vcmd->status == VHBA_REQ_PENDING) { + break; + } + } + + if (&vcmd->entry == &vdev->cmd_list) { + vcmd = NULL; + } + + return vcmd; +} + +static struct vhba_command *match_command (struct vhba_device *vdev, __u32 metatag) +{ + struct vhba_command *vcmd; + + list_for_each_entry(vcmd, &vdev->cmd_list, entry) { + if (vcmd->metatag == metatag) { + break; + } + } + + if (&vcmd->entry == &vdev->cmd_list) { + vcmd = NULL; + } + + return vcmd; +} + +static struct vhba_command *wait_command (struct vhba_device *vdev, unsigned long flags) +{ + struct vhba_command *vcmd; + DEFINE_WAIT(wait); + + while (!(vcmd = next_command(vdev))) { + if (signal_pending(current)) { + break; + } + + prepare_to_wait(&vdev->cmd_wq, &wait, TASK_INTERRUPTIBLE); + + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + schedule(); + + spin_lock_irqsave(&vdev->cmd_lock, flags); + } + + finish_wait(&vdev->cmd_wq, &wait); + if (vcmd) { + vcmd->status = VHBA_REQ_READING; + } + + return vcmd; +} + +static ssize_t vhba_ctl_read (struct file *file, char __user *buf, size_t buf_len, loff_t *offset) +{ + struct vhba_device *vdev; + struct vhba_command *vcmd; + ssize_t ret; + unsigned long flags; + + vdev = file->private_data; + + /* Get next command */ + if (file->f_flags & O_NONBLOCK) { + /* Non-blocking variant */ + spin_lock_irqsave(&vdev->cmd_lock, flags); + vcmd = next_command(vdev); + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + if (!vcmd) { + return -EWOULDBLOCK; + } + } else { + /* Blocking variant */ + spin_lock_irqsave(&vdev->cmd_lock, flags); + vcmd = wait_command(vdev, flags); + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + if (!vcmd) { + return -ERESTARTSYS; + } + } + + ret = do_request(vdev, vcmd->metatag, vcmd->cmd, buf, buf_len); + + spin_lock_irqsave(&vdev->cmd_lock, flags); + if (ret >= 0) { + vcmd->status = VHBA_REQ_SENT; + *offset += ret; + } else { + vcmd->status = VHBA_REQ_PENDING; + } + + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + return ret; +} + +static ssize_t vhba_ctl_write (struct file *file, const char __user *buf, size_t buf_len, loff_t *offset) +{ + struct vhba_device *vdev; + struct vhba_command *vcmd; + struct vhba_response res; + ssize_t ret; + unsigned long flags; + + if (buf_len < sizeof(res)) { + return -EIO; + } + + if (copy_from_user(&res, buf, sizeof(res))) { + return -EFAULT; + } + + vdev = file->private_data; + + spin_lock_irqsave(&vdev->cmd_lock, flags); + vcmd = match_command(vdev, res.metatag); + if (!vcmd || vcmd->status != VHBA_REQ_SENT) { + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + pr_debug("ctl dev #%u not expecting response\n", vdev->num); + return -EIO; + } + vcmd->status = VHBA_REQ_WRITING; + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + ret = do_response(vdev, vcmd->metatag, vcmd->cmd, buf + sizeof(res), buf_len - sizeof(res), &res); + + spin_lock_irqsave(&vdev->cmd_lock, flags); + if (ret >= 0) { +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 16, 0) + scsi_done(vcmd->cmd); +#else + vcmd->cmd->scsi_done(vcmd->cmd); +#endif + ret += sizeof(res); + + /* don't compete with vhba_device_dequeue */ + if (!list_empty(&vcmd->entry)) { + list_del_init(&vcmd->entry); + vhba_free_command(vcmd); + } + } else { + vcmd->status = VHBA_REQ_SENT; + } + + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + return ret; +} + +static long vhba_ctl_ioctl (struct file *file, unsigned int cmd, unsigned long arg) +{ + struct vhba_device *vdev = file->private_data; + struct vhba_host *vhost = platform_get_drvdata(&vhba_platform_device); + + switch (cmd) { + case 0xBEEF001: { + unsigned int ident[4]; /* host, channel, id, lun */ + + ident[0] = vhost->shost->host_no; + devnum_to_bus_and_id(vdev->num, &ident[1], &ident[2]); + ident[3] = 0; /* lun */ + + if (copy_to_user((void *) arg, ident, sizeof(ident))) { + return -EFAULT; + } + + return 0; + } + case 0xBEEF002: { + unsigned int devnum = vdev->num; + + if (copy_to_user((void *) arg, &devnum, sizeof(devnum))) { + return -EFAULT; + } + + return 0; + } + } + + return -ENOTTY; +} + +#ifdef CONFIG_COMPAT +static long vhba_ctl_compat_ioctl (struct file *file, unsigned int cmd, unsigned long arg) +{ + unsigned long compat_arg = (unsigned long)compat_ptr(arg); + return vhba_ctl_ioctl(file, cmd, compat_arg); +} +#endif + +static unsigned int vhba_ctl_poll (struct file *file, poll_table *wait) +{ + struct vhba_device *vdev = file->private_data; + unsigned int mask = 0; + unsigned long flags; + + poll_wait(file, &vdev->cmd_wq, wait); + + spin_lock_irqsave(&vdev->cmd_lock, flags); + if (next_command(vdev)) { + mask |= POLLIN | POLLRDNORM; + } + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + return mask; +} + +static int vhba_ctl_open (struct inode *inode, struct file *file) +{ + struct vhba_device *vdev; + int retval; + + pr_debug("ctl dev open\n"); + + /* check if vhba is probed */ + if (!platform_get_drvdata(&vhba_platform_device)) { + return -ENODEV; + } + + vdev = vhba_device_alloc(); + if (!vdev) { + return -ENOMEM; + } + + vdev->kbuf_size = VHBA_KBUF_SIZE; + vdev->kbuf = kzalloc(vdev->kbuf_size, GFP_KERNEL); + if (!vdev->kbuf) { + return -ENOMEM; + } + + if (!(retval = vhba_add_device(vdev))) { + file->private_data = vdev; + } + + vhba_device_put(vdev); + + return retval; +} + +static int vhba_ctl_release (struct inode *inode, struct file *file) +{ + struct vhba_device *vdev; + struct vhba_command *vcmd; + unsigned long flags; + + vdev = file->private_data; + + pr_debug("ctl dev release\n"); + + vhba_device_get(vdev); + vhba_remove_device(vdev); + + spin_lock_irqsave(&vdev->cmd_lock, flags); + list_for_each_entry(vcmd, &vdev->cmd_list, entry) { + WARN_ON(vcmd->status == VHBA_REQ_READING || vcmd->status == VHBA_REQ_WRITING); + + scmd_dbg(vcmd->cmd, "device released with command %lu (%p)\n", vcmd->metatag, vcmd->cmd); + vcmd->cmd->result = DID_NO_CONNECT << 16; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 16, 0) + scsi_done(vcmd->cmd); +#else + vcmd->cmd->scsi_done(vcmd->cmd); +#endif + vhba_free_command(vcmd); + } + INIT_LIST_HEAD(&vdev->cmd_list); + spin_unlock_irqrestore(&vdev->cmd_lock, flags); + + kfree(vdev->kbuf); + vdev->kbuf = NULL; + + vhba_device_put(vdev); + + return 0; +} + +static struct file_operations vhba_ctl_fops = { + .owner = THIS_MODULE, + .open = vhba_ctl_open, + .release = vhba_ctl_release, + .read = vhba_ctl_read, + .write = vhba_ctl_write, + .poll = vhba_ctl_poll, + .unlocked_ioctl = vhba_ctl_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = vhba_ctl_compat_ioctl, +#endif +}; + +static struct miscdevice vhba_miscdev = { + .minor = MISC_DYNAMIC_MINOR, + .name = "vhba_ctl", + .fops = &vhba_ctl_fops, +}; + +static int vhba_probe (struct platform_device *pdev) +{ + struct Scsi_Host *shost; + struct vhba_host *vhost; + int i; + + vhba_can_queue = clamp(vhba_can_queue, 1, 256); + + shost = scsi_host_alloc(&vhba_template, sizeof(struct vhba_host)); + if (!shost) { + return -ENOMEM; + } + + shost->max_channel = VHBA_MAX_BUS-1; + shost->max_id = VHBA_MAX_ID; + /* we don't support lun > 0 */ + shost->max_lun = 1; + shost->max_cmd_len = MAX_COMMAND_SIZE; + shost->can_queue = vhba_can_queue; + shost->cmd_per_lun = vhba_can_queue; + + vhost = (struct vhba_host *)shost->hostdata; + memset(vhost, 0, sizeof(struct vhba_host)); + + vhost->shost = shost; + vhost->num_devices = 0; + spin_lock_init(&vhost->dev_lock); + spin_lock_init(&vhost->cmd_lock); + INIT_WORK(&vhost->scan_devices, vhba_scan_devices); + vhost->cmd_next = 0; + vhost->commands = kzalloc(vhba_can_queue * sizeof(struct vhba_command), GFP_KERNEL); + if (!vhost->commands) { + return -ENOMEM; + } + + for (i = 0; i < vhba_can_queue; i++) { + vhost->commands[i].status = VHBA_REQ_FREE; + } + + platform_set_drvdata(pdev, vhost); + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 4, 0) + i = scsi_init_shared_tag_map(shost, vhba_can_queue); + if (i) return i; +#endif + + if (scsi_add_host(shost, &pdev->dev)) { + scsi_host_put(shost); + return -ENOMEM; + } + + return 0; +} + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 11, 0) +static int vhba_remove (struct platform_device *pdev) +#else +static void vhba_remove (struct platform_device *pdev) +#endif +{ + struct vhba_host *vhost; + struct Scsi_Host *shost; + + vhost = platform_get_drvdata(pdev); + shost = vhost->shost; + + scsi_remove_host(shost); + scsi_host_put(shost); + + kfree(vhost->commands); + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 11, 0) + return 0; +#endif +} + +static void vhba_release (struct device * dev) +{ + return; +} + +static struct platform_device vhba_platform_device = { + .name = "vhba", + .id = -1, + .dev = { + .release = vhba_release, + }, +}; + +static struct platform_driver vhba_platform_driver = { + .driver = { + .owner = THIS_MODULE, + .name = "vhba", + }, + .probe = vhba_probe, + .remove = vhba_remove, +}; + +static int __init vhba_init (void) +{ + int ret; + + ret = platform_device_register(&vhba_platform_device); + if (ret < 0) { + return ret; + } + + ret = platform_driver_register(&vhba_platform_driver); + if (ret < 0) { + platform_device_unregister(&vhba_platform_device); + return ret; + } + + ret = misc_register(&vhba_miscdev); + if (ret < 0) { + platform_driver_unregister(&vhba_platform_driver); + platform_device_unregister(&vhba_platform_device); + return ret; + } + + return 0; +} + +static void __exit vhba_exit(void) +{ + misc_deregister(&vhba_miscdev); + platform_driver_unregister(&vhba_platform_driver); + platform_device_unregister(&vhba_platform_device); +} + +module_init(vhba_init); +module_exit(vhba_exit); + diff --git a/include/linux/pagemap.h b/include/linux/pagemap.h index 47bfc6b1b632..435901dbc742 100644 --- a/include/linux/pagemap.h +++ b/include/linux/pagemap.h @@ -1373,7 +1373,7 @@ struct readahead_control { ._index = i, \ } -#define VM_READAHEAD_PAGES (SZ_128K / PAGE_SIZE) +#define VM_READAHEAD_PAGES (SZ_8M / PAGE_SIZE) void page_cache_ra_unbounded(struct readahead_control *, unsigned long nr_to_read, unsigned long lookahead_count); diff --git a/include/linux/user_namespace.h b/include/linux/user_namespace.h index 7183e5aca282..56573371a2f8 100644 --- a/include/linux/user_namespace.h +++ b/include/linux/user_namespace.h @@ -159,6 +159,8 @@ static inline void set_userns_rlimit_max(struct user_namespace *ns, #ifdef CONFIG_USER_NS +extern int unprivileged_userns_clone; + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { if (ns) @@ -192,6 +194,8 @@ extern bool current_in_userns(const struct user_namespace *target_ns); struct ns_common *ns_get_owner(struct ns_common *ns); #else +#define unprivileged_userns_clone 0 + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { return &init_user_ns; diff --git a/include/linux/wait.h b/include/linux/wait.h index 6d90ad974408..d04768b01364 100644 --- a/include/linux/wait.h +++ b/include/linux/wait.h @@ -163,6 +163,7 @@ static inline bool wq_has_sleeper(struct wait_queue_head *wq_head) extern void add_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void add_wait_queue_exclusive(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); +extern void add_wait_queue_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void add_wait_queue_priority(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void remove_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); @@ -1192,6 +1193,7 @@ do { \ */ void prepare_to_wait(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); bool prepare_to_wait_exclusive(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); +void prepare_to_wait_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); long prepare_to_wait_event(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); void finish_wait(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); long wait_woken(struct wait_queue_entry *wq_entry, unsigned mode, long timeout); diff --git a/init/Kconfig b/init/Kconfig index 5ab47c346ef9..a800d7529f59 100644 --- a/init/Kconfig +++ b/init/Kconfig @@ -162,6 +162,10 @@ config THREAD_INFO_IN_TASK menu "General setup" +config CACHY + bool "Some kernel tweaks by CachyOS" + default y + config BROKEN bool @@ -1328,6 +1332,22 @@ config USER_NS If unsure, say N. +config USER_NS_UNPRIVILEGED + bool "Allow unprivileged users to create namespaces" + default y + depends on USER_NS + help + When disabled, unprivileged users will not be able to create + new namespaces. Allowing users to create their own namespaces + has been part of several recent local privilege escalation + exploits, so if you need user namespaces but are + paranoid^Wsecurity-conscious you want to disable this. + + This setting can be overridden at runtime via the + kernel.unprivileged_userns_clone sysctl. + + If unsure, say Y. + config PID_NS bool "PID Namespaces" default y @@ -1470,6 +1490,12 @@ config CC_OPTIMIZE_FOR_PERFORMANCE with the "-O2" compiler flag for best performance and most helpful compile-time warnings. +config CC_OPTIMIZE_FOR_PERFORMANCE_O3 + bool "Optimize more for performance (-O3)" + help + Choosing this option will pass "-O3" to your compiler to optimize + the kernel yet more for performance. + config CC_OPTIMIZE_FOR_SIZE bool "Optimize for size (-Os)" help diff --git a/kernel/Kconfig.hz b/kernel/Kconfig.hz index 38ef6d06888e..0f78364efd4f 100644 --- a/kernel/Kconfig.hz +++ b/kernel/Kconfig.hz @@ -40,6 +40,27 @@ choice on SMP and NUMA systems and exactly dividing by both PAL and NTSC frame rates for video and multimedia work. + config HZ_500 + bool "500 HZ" + help + 500 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_600 + bool "600 HZ" + help + 600 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_750 + bool "750 HZ" + help + 750 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + config HZ_1000 bool "1000 HZ" help @@ -53,6 +74,9 @@ config HZ default 100 if HZ_100 default 250 if HZ_250 default 300 if HZ_300 + default 500 if HZ_500 + default 600 if HZ_600 + default 750 if HZ_750 default 1000 if HZ_1000 config SCHED_HRTICK diff --git a/kernel/Kconfig.preempt b/kernel/Kconfig.preempt index 54ea59ff8fbe..18f87e0dd137 100644 --- a/kernel/Kconfig.preempt +++ b/kernel/Kconfig.preempt @@ -88,7 +88,7 @@ endchoice config PREEMPT_RT bool "Fully Preemptible Kernel (Real-Time)" - depends on EXPERT && ARCH_SUPPORTS_RT && !COMPILE_TEST + depends on ARCH_SUPPORTS_RT && !COMPILE_TEST select PREEMPTION help This option turns the kernel into a real-time kernel by replacing diff --git a/kernel/fork.c b/kernel/fork.c index ca2ca3884f76..d47d85b68dab 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -106,6 +106,10 @@ #include #include +#ifdef CONFIG_USER_NS +#include +#endif + #include #include #include @@ -2171,6 +2175,10 @@ __latent_entropy struct task_struct *copy_process( if ((clone_flags & (CLONE_NEWUSER|CLONE_FS)) == (CLONE_NEWUSER|CLONE_FS)) return ERR_PTR(-EINVAL); + if ((clone_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) + if (!capable(CAP_SYS_ADMIN)) + return ERR_PTR(-EPERM); + /* * Thread groups must share signals as well, and detached threads * can only be started up within the thread group. @@ -3324,6 +3332,12 @@ int ksys_unshare(unsigned long unshare_flags) if (unshare_flags & CLONE_NEWNS) unshare_flags |= CLONE_FS; + if ((unshare_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) { + err = -EPERM; + if (!capable(CAP_SYS_ADMIN)) + goto bad_unshare_out; + } + err = check_unshare_flags(unshare_flags); if (err) goto bad_unshare_out; diff --git a/kernel/locking/rwsem.c b/kernel/locking/rwsem.c index 2ddb827e3bea..464049c4af3f 100644 --- a/kernel/locking/rwsem.c +++ b/kernel/locking/rwsem.c @@ -747,6 +747,7 @@ rwsem_spin_on_owner(struct rw_semaphore *sem) struct task_struct *new, *owner; unsigned long flags, new_flags; enum owner_state state; + int i = 0; lockdep_assert_preemption_disabled(); @@ -783,7 +784,8 @@ rwsem_spin_on_owner(struct rw_semaphore *sem) break; } - cpu_relax(); + if (i++ > 1000) + cpu_relax(); } return state; diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c index e197b0edcfb2..7e0ac34b351d 100644 --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -76,10 +76,19 @@ unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_LOG; * * (default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) */ +#ifdef CONFIG_CACHY +unsigned int sysctl_sched_base_slice = 350000ULL; +static unsigned int normalized_sysctl_sched_base_slice = 350000ULL; +#else unsigned int sysctl_sched_base_slice = 750000ULL; static unsigned int normalized_sysctl_sched_base_slice = 750000ULL; +#endif +#ifdef CONFIG_CACHY +const_debug unsigned int sysctl_sched_migration_cost = 300000UL; +#else const_debug unsigned int sysctl_sched_migration_cost = 500000UL; +#endif static int __init setup_sched_thermal_decay_shift(char *str) { @@ -124,8 +133,12 @@ int __weak arch_asym_cpu_priority(int cpu) * * (default: 5 msec, units: microseconds) */ +#ifdef CONFIG_CACHY +static unsigned int sysctl_sched_cfs_bandwidth_slice = 3000UL; +#else static unsigned int sysctl_sched_cfs_bandwidth_slice = 5000UL; #endif +#endif #ifdef CONFIG_NUMA_BALANCING /* Restrict the NUMA promotion throughput (MB/s) for each target node. */ @@ -9441,12 +9454,11 @@ int can_migrate_task(struct task_struct *p, struct lb_env *env) return 0; /* Prevent to re-select dst_cpu via env's CPUs: */ - for_each_cpu_and(cpu, env->dst_grpmask, env->cpus) { - if (cpumask_test_cpu(cpu, p->cpus_ptr)) { - env->flags |= LBF_DST_PINNED; - env->new_dst_cpu = cpu; - break; - } + cpu = cpumask_first_and_and(env->dst_grpmask, env->cpus, p->cpus_ptr); + + if (cpu < nr_cpu_ids) { + env->flags |= LBF_DST_PINNED; + env->new_dst_cpu = cpu; } return 0; diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h index 1aa65a0ac586..4711759b84c2 100644 --- a/kernel/sched/sched.h +++ b/kernel/sched/sched.h @@ -2837,7 +2837,7 @@ extern void deactivate_task(struct rq *rq, struct task_struct *p, int flags); extern void wakeup_preempt(struct rq *rq, struct task_struct *p, int flags); -#ifdef CONFIG_PREEMPT_RT +#if defined(CONFIG_PREEMPT_RT) || defined(CONFIG_CACHY) # define SCHED_NR_MIGRATE_BREAK 8 #else # define SCHED_NR_MIGRATE_BREAK 32 diff --git a/kernel/sched/wait.c b/kernel/sched/wait.c index 51e38f5f4701..c5cc616484ba 100644 --- a/kernel/sched/wait.c +++ b/kernel/sched/wait.c @@ -47,6 +47,17 @@ void add_wait_queue_priority(struct wait_queue_head *wq_head, struct wait_queue_ } EXPORT_SYMBOL_GPL(add_wait_queue_priority); +void add_wait_queue_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry) +{ + unsigned long flags; + + wq_entry->flags |= WQ_FLAG_EXCLUSIVE; + spin_lock_irqsave(&wq_head->lock, flags); + __add_wait_queue(wq_head, wq_entry); + spin_unlock_irqrestore(&wq_head->lock, flags); +} +EXPORT_SYMBOL(add_wait_queue_exclusive_lifo); + void remove_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry) { unsigned long flags; @@ -258,6 +269,19 @@ prepare_to_wait_exclusive(struct wait_queue_head *wq_head, struct wait_queue_ent } EXPORT_SYMBOL(prepare_to_wait_exclusive); +void prepare_to_wait_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state) +{ + unsigned long flags; + + wq_entry->flags |= WQ_FLAG_EXCLUSIVE; + spin_lock_irqsave(&wq_head->lock, flags); + if (list_empty(&wq_entry->entry)) + __add_wait_queue(wq_head, wq_entry); + set_current_state(state); + spin_unlock_irqrestore(&wq_head->lock, flags); +} +EXPORT_SYMBOL(prepare_to_wait_exclusive_lifo); + void init_wait_entry(struct wait_queue_entry *wq_entry, int flags) { wq_entry->flags = flags; diff --git a/kernel/sysctl.c b/kernel/sysctl.c index cb57da499ebb..f7f1c25b30fe 100644 --- a/kernel/sysctl.c +++ b/kernel/sysctl.c @@ -80,6 +80,9 @@ #ifdef CONFIG_RT_MUTEXES #include #endif +#ifdef CONFIG_USER_NS +#include +#endif /* shared constants to be used in various sysctls */ const int sysctl_vals[] = { 0, 1, 2, 3, 4, 100, 200, 1000, 3000, INT_MAX, 65535, -1 }; @@ -1617,6 +1620,15 @@ static const struct ctl_table kern_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, +#ifdef CONFIG_USER_NS + { + .procname = "unprivileged_userns_clone", + .data = &unprivileged_userns_clone, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif #ifdef CONFIG_PROC_SYSCTL { .procname = "tainted", diff --git a/kernel/user_namespace.c b/kernel/user_namespace.c index aa0b2e47f2f2..d74d857b1696 100644 --- a/kernel/user_namespace.c +++ b/kernel/user_namespace.c @@ -22,6 +22,13 @@ #include #include +/* sysctl */ +#ifdef CONFIG_USER_NS_UNPRIVILEGED +int unprivileged_userns_clone = 1; +#else +int unprivileged_userns_clone; +#endif + static struct kmem_cache *user_ns_cachep __ro_after_init; static DEFINE_MUTEX(userns_state_mutex); diff --git a/mm/Kconfig b/mm/Kconfig index 1b501db06417..c5b54d2197aa 100644 --- a/mm/Kconfig +++ b/mm/Kconfig @@ -691,7 +691,7 @@ config COMPACTION config COMPACT_UNEVICTABLE_DEFAULT int depends on COMPACTION - default 0 if PREEMPT_RT + default 0 if PREEMPT_RT || CACHY default 1 # diff --git a/mm/compaction.c b/mm/compaction.c index a3203d97123e..890344449a10 100644 --- a/mm/compaction.c +++ b/mm/compaction.c @@ -1923,7 +1923,11 @@ static int sysctl_compact_unevictable_allowed __read_mostly = CONFIG_COMPACT_UNE * aggressively the kernel should compact memory in the * background. It takes values in the range [0, 100]. */ +#ifdef CONFIG_CACHY +static unsigned int __read_mostly sysctl_compaction_proactiveness; +#else static unsigned int __read_mostly sysctl_compaction_proactiveness = 20; +#endif static int sysctl_extfrag_threshold = 500; static int __read_mostly sysctl_compact_memory; diff --git a/mm/huge_memory.c b/mm/huge_memory.c index 373781b21e5c..e68aafc49f1f 100644 --- a/mm/huge_memory.c +++ b/mm/huge_memory.c @@ -64,7 +64,11 @@ unsigned long transparent_hugepage_flags __read_mostly = #ifdef CONFIG_TRANSPARENT_HUGEPAGE_MADVISE (1<> (20 - PAGE_SHIFT); /* Use a smaller cluster for small-memory machines */ @@ -1092,4 +1096,5 @@ void __init swap_setup(void) * Right now other parts of the system means that we * _really_ don't want to cluster much more */ +#endif } diff --git a/mm/vmpressure.c b/mm/vmpressure.c index bd5183dfd879..3a410f53a07c 100644 --- a/mm/vmpressure.c +++ b/mm/vmpressure.c @@ -43,7 +43,11 @@ static const unsigned long vmpressure_win = SWAP_CLUSTER_MAX * 16; * essence, they are percents: the higher the value, the more number * unsuccessful reclaims there were. */ +#ifdef CONFIG_CACHY +static const unsigned int vmpressure_level_med = 65; +#else static const unsigned int vmpressure_level_med = 60; +#endif static const unsigned int vmpressure_level_critical = 95; /* diff --git a/mm/vmscan.c b/mm/vmscan.c index fada3b35aff8..04d97600f3e4 100644 --- a/mm/vmscan.c +++ b/mm/vmscan.c @@ -200,7 +200,11 @@ struct scan_control { /* * From 0 .. MAX_SWAPPINESS. Higher means more swappy. */ +#ifdef CONFIG_CACHY +int vm_swappiness = 100; +#else int vm_swappiness = 60; +#endif #ifdef CONFIG_MEMCG diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index e4decfb270fa..38bff2d8a740 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -634,7 +634,7 @@ static int inet_csk_wait_for_connect(struct sock *sk, long timeo) * having to remove and re-insert us on the wait queue. */ for (;;) { - prepare_to_wait_exclusive(sk_sleep(sk), &wait, + prepare_to_wait_exclusive_lifo(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); release_sock(sk); if (reqsk_queue_empty(&icsk->icsk_accept_queue)) -- 2.49.0.391.g4bbb303af6 From cebdaab1a3aebe9e8986c80bd47e14e4e45b0db7 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:32:29 +0200 Subject: [PATCH 6/9] crypto Signed-off-by: Peter Jung --- MAINTAINERS | 1 + arch/x86/Kconfig | 2 +- arch/x86/crypto/aesni-intel_glue.c | 22 +- arch/x86/include/asm/cpufeatures.h | 1 + arch/x86/kernel/cpu/intel.c | 22 ++ arch/x86/lib/Makefile | 2 +- arch/x86/lib/crc-pclmul-consts.h | 99 +++++ arch/x86/lib/crc-pclmul-template.S | 584 ++++++++++++++++++++++++++++ arch/x86/lib/crc-pclmul-template.h | 81 ++++ arch/x86/lib/crc-t10dif-glue.c | 23 +- arch/x86/lib/crc16-msb-pclmul.S | 6 + arch/x86/lib/crc32-glue.c | 51 +-- arch/x86/lib/crc32-pclmul.S | 219 +---------- arch/x86/lib/crct10dif-pcl-asm_64.S | 332 ---------------- drivers/nvme/host/Kconfig | 3 +- drivers/nvme/host/tcp.c | 122 ++---- drivers/nvme/target/tcp.c | 90 ++--- include/linux/skbuff.h | 7 +- net/core/datagram.c | 46 +-- scripts/gen-crc-consts.py | 238 ++++++++++++ 20 files changed, 1143 insertions(+), 808 deletions(-) create mode 100644 arch/x86/lib/crc-pclmul-consts.h create mode 100644 arch/x86/lib/crc-pclmul-template.S create mode 100644 arch/x86/lib/crc-pclmul-template.h create mode 100644 arch/x86/lib/crc16-msb-pclmul.S delete mode 100644 arch/x86/lib/crct10dif-pcl-asm_64.S create mode 100755 scripts/gen-crc-consts.py diff --git a/MAINTAINERS b/MAINTAINERS index 00e94bec401e..3e00f5654f60 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -6140,6 +6140,7 @@ F: Documentation/staging/crc* F: arch/*/lib/crc* F: include/linux/crc* F: lib/crc* +F: scripts/gen-crc-consts.py CREATIVE SB0540 M: Bastien Nocera diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 088f7555e1ac..8dc1133e3384 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -77,7 +77,7 @@ config X86 select ARCH_HAS_CPU_FINALIZE_INIT select ARCH_HAS_CPU_PASID if IOMMU_SVA select ARCH_HAS_CRC32 - select ARCH_HAS_CRC_T10DIF if X86_64 + select ARCH_HAS_CRC_T10DIF select ARCH_HAS_CURRENT_STACK_POINTER select ARCH_HAS_DEBUG_VIRTUAL select ARCH_HAS_DEBUG_VM_PGTABLE if !X86_PAE diff --git a/arch/x86/crypto/aesni-intel_glue.c b/arch/x86/crypto/aesni-intel_glue.c index 11e95fc62636..3e9ab5cdade4 100644 --- a/arch/x86/crypto/aesni-intel_glue.c +++ b/arch/x86/crypto/aesni-intel_glue.c @@ -1536,26 +1536,6 @@ DEFINE_GCM_ALGS(vaes_avx10_512, FLAG_AVX10_512, AES_GCM_KEY_AVX10_SIZE, 800); #endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */ -/* - * This is a list of CPU models that are known to suffer from downclocking when - * zmm registers (512-bit vectors) are used. On these CPUs, the AES mode - * implementations with zmm registers won't be used by default. Implementations - * with ymm registers (256-bit vectors) will be used by default instead. - */ -static const struct x86_cpu_id zmm_exclusion_list[] = { - X86_MATCH_VFM(INTEL_SKYLAKE_X, 0), - X86_MATCH_VFM(INTEL_ICELAKE_X, 0), - X86_MATCH_VFM(INTEL_ICELAKE_D, 0), - X86_MATCH_VFM(INTEL_ICELAKE, 0), - X86_MATCH_VFM(INTEL_ICELAKE_L, 0), - X86_MATCH_VFM(INTEL_ICELAKE_NNPI, 0), - X86_MATCH_VFM(INTEL_TIGERLAKE_L, 0), - X86_MATCH_VFM(INTEL_TIGERLAKE, 0), - /* Allow Rocket Lake and later, and Sapphire Rapids and later. */ - /* Also allow AMD CPUs (starting with Zen 4, the first with AVX-512). */ - {}, -}; - static int __init register_avx_algs(void) { int err; @@ -1600,7 +1580,7 @@ static int __init register_avx_algs(void) if (err) return err; - if (x86_match_cpu(zmm_exclusion_list)) { + if (boot_cpu_has(X86_FEATURE_PREFER_YMM)) { int i; aes_xts_alg_vaes_avx10_512.base.cra_priority = 1; diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h index 8770dc185fe9..8c77b9337ca3 100644 --- a/arch/x86/include/asm/cpufeatures.h +++ b/arch/x86/include/asm/cpufeatures.h @@ -484,6 +484,7 @@ #define X86_FEATURE_AMD_FAST_CPPC (21*32 + 5) /* Fast CPPC */ #define X86_FEATURE_AMD_HETEROGENEOUS_CORES (21*32 + 6) /* Heterogeneous Core Topology */ #define X86_FEATURE_AMD_WORKLOAD_CLASS (21*32 + 7) /* Workload Classification */ +#define X86_FEATURE_PREFER_YMM (21*32 + 8) /* Avoid ZMM registers due to downclocking */ /* * BUG word(s) diff --git a/arch/x86/kernel/cpu/intel.c b/arch/x86/kernel/cpu/intel.c index 134368a3f4b1..5fe563eeb17d 100644 --- a/arch/x86/kernel/cpu/intel.c +++ b/arch/x86/kernel/cpu/intel.c @@ -521,6 +521,25 @@ static void init_intel_misc_features(struct cpuinfo_x86 *c) wrmsrl(MSR_MISC_FEATURES_ENABLES, msr); } +/* + * This is a list of Intel CPUs that are known to suffer from downclocking when + * ZMM registers (512-bit vectors) are used. On these CPUs, when the kernel + * executes SIMD-optimized code such as cryptography functions or CRCs, it + * should prefer 256-bit (YMM) code to 512-bit (ZMM) code. + */ +static const struct x86_cpu_id zmm_exclusion_list[] = { + X86_MATCH_VFM(INTEL_SKYLAKE_X, 0), + X86_MATCH_VFM(INTEL_ICELAKE_X, 0), + X86_MATCH_VFM(INTEL_ICELAKE_D, 0), + X86_MATCH_VFM(INTEL_ICELAKE, 0), + X86_MATCH_VFM(INTEL_ICELAKE_L, 0), + X86_MATCH_VFM(INTEL_ICELAKE_NNPI, 0), + X86_MATCH_VFM(INTEL_TIGERLAKE_L, 0), + X86_MATCH_VFM(INTEL_TIGERLAKE, 0), + /* Allow Rocket Lake and later, and Sapphire Rapids and later. */ + {}, +}; + static void init_intel(struct cpuinfo_x86 *c) { early_init_intel(c); @@ -601,6 +620,9 @@ static void init_intel(struct cpuinfo_x86 *c) } #endif + if (x86_match_cpu(zmm_exclusion_list)) + set_cpu_cap(c, X86_FEATURE_PREFER_YMM); + /* Work around errata */ srat_detect_node(c); diff --git a/arch/x86/lib/Makefile b/arch/x86/lib/Makefile index 8a59c61624c2..08496e221a7d 100644 --- a/arch/x86/lib/Makefile +++ b/arch/x86/lib/Makefile @@ -43,7 +43,7 @@ crc32-x86-y := crc32-glue.o crc32-pclmul.o crc32-x86-$(CONFIG_64BIT) += crc32c-3way.o obj-$(CONFIG_CRC_T10DIF_ARCH) += crc-t10dif-x86.o -crc-t10dif-x86-y := crc-t10dif-glue.o crct10dif-pcl-asm_64.o +crc-t10dif-x86-y := crc-t10dif-glue.o crc16-msb-pclmul.o obj-y += msr.o msr-reg.o msr-reg-export.o hweight.o obj-y += iomem.o diff --git a/arch/x86/lib/crc-pclmul-consts.h b/arch/x86/lib/crc-pclmul-consts.h new file mode 100644 index 000000000000..089954988f97 --- /dev/null +++ b/arch/x86/lib/crc-pclmul-consts.h @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * CRC constants generated by: + * + * ./scripts/gen-crc-consts.py x86_pclmul crc16_msb_0x8bb7,crc32_lsb_0xedb88320 + * + * Do not edit manually. + */ + +/* + * CRC folding constants generated for most-significant-bit-first CRC-16 using + * G(x) = x^16 + x^15 + x^11 + x^9 + x^8 + x^7 + x^5 + x^4 + x^2 + x^1 + x^0 + */ +static const struct { + u8 bswap_mask[16]; + u64 fold_across_2048_bits_consts[2]; + u64 fold_across_1024_bits_consts[2]; + u64 fold_across_512_bits_consts[2]; + u64 fold_across_256_bits_consts[2]; + u64 fold_across_128_bits_consts[2]; + u8 shuf_table[48]; + u64 barrett_reduction_consts[2]; +} crc16_msb_0x8bb7_consts ____cacheline_aligned __maybe_unused = { + .bswap_mask = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + .fold_across_2048_bits_consts = { + 0xdccf000000000000, /* LO64_TERMS: (x^2000 mod G) * x^48 */ + 0x4b0b000000000000, /* HI64_TERMS: (x^2064 mod G) * x^48 */ + }, + .fold_across_1024_bits_consts = { + 0x9d9d000000000000, /* LO64_TERMS: (x^976 mod G) * x^48 */ + 0x7cf5000000000000, /* HI64_TERMS: (x^1040 mod G) * x^48 */ + }, + .fold_across_512_bits_consts = { + 0x044c000000000000, /* LO64_TERMS: (x^464 mod G) * x^48 */ + 0xe658000000000000, /* HI64_TERMS: (x^528 mod G) * x^48 */ + }, + .fold_across_256_bits_consts = { + 0x6ee3000000000000, /* LO64_TERMS: (x^208 mod G) * x^48 */ + 0xe7b5000000000000, /* HI64_TERMS: (x^272 mod G) * x^48 */ + }, + .fold_across_128_bits_consts = { + 0x2d56000000000000, /* LO64_TERMS: (x^80 mod G) * x^48 */ + 0x06df000000000000, /* HI64_TERMS: (x^144 mod G) * x^48 */ + }, + .shuf_table = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }, + .barrett_reduction_consts = { + 0x8bb7000000000000, /* LO64_TERMS: (G - x^16) * x^48 */ + 0xf65a57f81d33a48a, /* HI64_TERMS: (floor(x^79 / G) * x) - x^64 */ + }, +}; + +/* + * CRC folding constants generated for least-significant-bit-first CRC-32 using + * G(x) = x^32 + x^26 + x^23 + x^22 + x^16 + x^12 + x^11 + x^10 + x^8 + x^7 + + * x^5 + x^4 + x^2 + x^1 + x^0 + */ +static const struct { + u64 fold_across_2048_bits_consts[2]; + u64 fold_across_1024_bits_consts[2]; + u64 fold_across_512_bits_consts[2]; + u64 fold_across_256_bits_consts[2]; + u64 fold_across_128_bits_consts[2]; + u8 shuf_table[48]; + u64 barrett_reduction_consts[2]; +} crc32_lsb_0xedb88320_consts ____cacheline_aligned __maybe_unused = { + .fold_across_2048_bits_consts = { + 0x00000000ce3371cb, /* HI64_TERMS: (x^2079 mod G) * x^32 */ + 0x00000000e95c1271, /* LO64_TERMS: (x^2015 mod G) * x^32 */ + }, + .fold_across_1024_bits_consts = { + 0x0000000033fff533, /* HI64_TERMS: (x^1055 mod G) * x^32 */ + 0x00000000910eeec1, /* LO64_TERMS: (x^991 mod G) * x^32 */ + }, + .fold_across_512_bits_consts = { + 0x000000008f352d95, /* HI64_TERMS: (x^543 mod G) * x^32 */ + 0x000000001d9513d7, /* LO64_TERMS: (x^479 mod G) * x^32 */ + }, + .fold_across_256_bits_consts = { + 0x00000000f1da05aa, /* HI64_TERMS: (x^287 mod G) * x^32 */ + 0x0000000081256527, /* LO64_TERMS: (x^223 mod G) * x^32 */ + }, + .fold_across_128_bits_consts = { + 0x00000000ae689191, /* HI64_TERMS: (x^159 mod G) * x^32 */ + 0x00000000ccaa009e, /* LO64_TERMS: (x^95 mod G) * x^32 */ + }, + .shuf_table = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }, + .barrett_reduction_consts = { + 0xb4e5b025f7011641, /* HI64_TERMS: floor(x^95 / G) */ + 0x00000001db710640, /* LO64_TERMS: (G - x^32) * x^31 */ + }, +}; diff --git a/arch/x86/lib/crc-pclmul-template.S b/arch/x86/lib/crc-pclmul-template.S new file mode 100644 index 000000000000..dc91cc074b30 --- /dev/null +++ b/arch/x86/lib/crc-pclmul-template.S @@ -0,0 +1,584 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +// +// Template to generate [V]PCLMULQDQ-based CRC functions for x86 +// +// Copyright 2025 Google LLC +// +// Author: Eric Biggers + +#include + +// Offsets within the generated constants table +.set OFFSETOF_BSWAP_MASK, -5*16 // msb-first CRCs only +.set OFFSETOF_FOLD_ACROSS_2048_BITS_CONSTS, -4*16 // must precede next +.set OFFSETOF_FOLD_ACROSS_1024_BITS_CONSTS, -3*16 // must precede next +.set OFFSETOF_FOLD_ACROSS_512_BITS_CONSTS, -2*16 // must precede next +.set OFFSETOF_FOLD_ACROSS_256_BITS_CONSTS, -1*16 // must precede next +.set OFFSETOF_FOLD_ACROSS_128_BITS_CONSTS, 0*16 // must be 0 +.set OFFSETOF_SHUF_TABLE, 1*16 +.set OFFSETOF_BARRETT_REDUCTION_CONSTS, 4*16 + +// Emit a VEX (or EVEX) coded instruction if allowed, or emulate it using the +// corresponding non-VEX instruction plus any needed moves. The supported +// instruction formats are: +// +// - Two-arg [src, dst], where the non-VEX format is the same. +// - Three-arg [src1, src2, dst] where the non-VEX format is +// [src1, src2_and_dst]. If src2 != dst, then src1 must != dst too. +// +// \insn gives the instruction without a "v" prefix and including any immediate +// argument if needed to make the instruction follow one of the above formats. +// If \unaligned_mem_tmp is given, then the emitted non-VEX code moves \arg1 to +// it first; this is needed when \arg1 is an unaligned mem operand. +.macro _cond_vex insn:req, arg1:req, arg2:req, arg3, unaligned_mem_tmp +.if AVX_LEVEL == 0 + // VEX not allowed. Emulate it. + .ifnb \arg3 // Three-arg [src1, src2, dst] + .ifc "\arg2", "\arg3" // src2 == dst? + .ifnb \unaligned_mem_tmp + movdqu \arg1, \unaligned_mem_tmp + \insn \unaligned_mem_tmp, \arg3 + .else + \insn \arg1, \arg3 + .endif + .else // src2 != dst + .ifc "\arg1", "\arg3" + .error "Can't have src1 == dst when src2 != dst" + .endif + .ifnb \unaligned_mem_tmp + movdqu \arg1, \unaligned_mem_tmp + movdqa \arg2, \arg3 + \insn \unaligned_mem_tmp, \arg3 + .else + movdqa \arg2, \arg3 + \insn \arg1, \arg3 + .endif + .endif + .else // Two-arg [src, dst] + .ifnb \unaligned_mem_tmp + movdqu \arg1, \unaligned_mem_tmp + \insn \unaligned_mem_tmp, \arg2 + .else + \insn \arg1, \arg2 + .endif + .endif +.else + // VEX is allowed. Emit the desired instruction directly. + .ifnb \arg3 + v\insn \arg1, \arg2, \arg3 + .else + v\insn \arg1, \arg2 + .endif +.endif +.endm + +// Broadcast an aligned 128-bit mem operand to all 128-bit lanes of a vector +// register of length VL. +.macro _vbroadcast src, dst +.if VL == 16 + _cond_vex movdqa, \src, \dst +.elseif VL == 32 + vbroadcasti128 \src, \dst +.else + vbroadcasti32x4 \src, \dst +.endif +.endm + +// Load \vl bytes from the unaligned mem operand \src into \dst, and if the CRC +// is msb-first use \bswap_mask to reflect the bytes within each 128-bit lane. +.macro _load_data vl, src, bswap_mask, dst +.if \vl < 64 + _cond_vex movdqu, "\src", \dst +.else + vmovdqu8 \src, \dst +.endif +.if !LSB_CRC + _cond_vex pshufb, \bswap_mask, \dst, \dst +.endif +.endm + +.macro _prepare_v0 vl, v0, v1, bswap_mask +.if LSB_CRC + .if \vl < 64 + _cond_vex pxor, (BUF), \v0, \v0, unaligned_mem_tmp=\v1 + .else + vpxorq (BUF), \v0, \v0 + .endif +.else + _load_data \vl, (BUF), \bswap_mask, \v1 + .if \vl < 64 + _cond_vex pxor, \v1, \v0, \v0 + .else + vpxorq \v1, \v0, \v0 + .endif +.endif +.endm + +// The x^0..x^63 terms, i.e. poly128 mod x^64, i.e. the physically low qword for +// msb-first order or the physically high qword for lsb-first order +#define LO64_TERMS 0 + +// The x^64..x^127 terms, i.e. floor(poly128 / x^64), i.e. the physically high +// qword for msb-first order or the physically low qword for lsb-first order +#define HI64_TERMS 1 + +// Multiply the given \src1_terms of each 128-bit lane of \src1 by the given +// \src2_terms of each 128-bit lane of \src2, and write the result(s) to \dst. +.macro _pclmulqdq src1, src1_terms, src2, src2_terms, dst + _cond_vex "pclmulqdq $((\src1_terms ^ LSB_CRC) << 4) ^ (\src2_terms ^ LSB_CRC),", \ + \src1, \src2, \dst +.endm + +// Fold \acc into \data and store the result back into \acc. \data can be an +// unaligned mem operand if using VEX is allowed and the CRC is lsb-first so no +// byte-reflection is needed; otherwise it must be a vector register. \consts +// is a vector register containing the needed fold constants, and \tmp is a +// temporary vector register. All arguments must be the same length. +.macro _fold_vec acc, data, consts, tmp + _pclmulqdq \consts, HI64_TERMS, \acc, HI64_TERMS, \tmp + _pclmulqdq \consts, LO64_TERMS, \acc, LO64_TERMS, \acc +.if AVX_LEVEL < 10 + _cond_vex pxor, \data, \tmp, \tmp + _cond_vex pxor, \tmp, \acc, \acc +.else + vpternlogq $0x96, \data, \tmp, \acc +.endif +.endm + +// Fold \acc into \data and store the result back into \acc. \data is an +// unaligned mem operand, \consts is a vector register containing the needed +// fold constants, \bswap_mask is a vector register containing the +// byte-reflection table if the CRC is msb-first, and \tmp1 and \tmp2 are +// temporary vector registers. All arguments must have length \vl. +.macro _fold_vec_mem vl, acc, data, consts, bswap_mask, tmp1, tmp2 +.if AVX_LEVEL == 0 || !LSB_CRC + _load_data \vl, \data, \bswap_mask, \tmp1 + _fold_vec \acc, \tmp1, \consts, \tmp2 +.else + _fold_vec \acc, \data, \consts, \tmp1 +.endif +.endm + +// Load the constants for folding across 2**i vectors of length VL at a time +// into all 128-bit lanes of the vector register CONSTS. +.macro _load_vec_folding_consts i + _vbroadcast OFFSETOF_FOLD_ACROSS_128_BITS_CONSTS+(4-LOG2_VL-\i)*16(CONSTS_PTR), \ + CONSTS +.endm + +// Given vector registers \v0 and \v1 of length \vl, fold \v0 into \v1 and store +// the result back into \v0. If the remaining length mod \vl is nonzero, also +// fold \vl data bytes from BUF. For both operations the fold distance is \vl. +// \consts must be a register of length \vl containing the fold constants. +.macro _fold_vec_final vl, v0, v1, consts, bswap_mask, tmp1, tmp2 + _fold_vec \v0, \v1, \consts, \tmp1 + test $\vl, LEN8 + jz .Lfold_vec_final_done\@ + _fold_vec_mem \vl, \v0, (BUF), \consts, \bswap_mask, \tmp1, \tmp2 + add $\vl, BUF +.Lfold_vec_final_done\@: +.endm + +// This macro generates the body of a CRC function with the following prototype: +// +// crc_t crc_func(crc_t crc, const u8 *buf, size_t len, const void *consts); +// +// |crc| is the initial CRC, and crc_t is a data type wide enough to hold it. +// |buf| is the data to checksum. |len| is the data length in bytes, which must +// be at least 16. |consts| is a pointer to the fold_across_128_bits_consts +// field of the constants struct that was generated for the chosen CRC variant. +// +// Moving onto the macro parameters, \n is the number of bits in the CRC, e.g. +// 32 for a CRC-32. Currently the supported values are 8, 16, 32, and 64. If +// the file is compiled in i386 mode, then the maximum supported value is 32. +// +// \lsb_crc is 1 if the CRC processes the least significant bit of each byte +// first, i.e. maps bit0 to x^7, bit1 to x^6, ..., bit7 to x^0. \lsb_crc is 0 +// if the CRC processes the most significant bit of each byte first, i.e. maps +// bit0 to x^0, bit1 to x^1, bit7 to x^7. +// +// \vl is the maximum length of vector register to use in bytes: 16, 32, or 64. +// +// \avx_level is the level of AVX support to use: 0 for SSE only, 2 for AVX2, or +// 10 for AVX10 or AVX512. +// +// If \vl == 16 && \avx_level == 0, the generated code requires: +// PCLMULQDQ && SSE4.1. (Note: all known CPUs with PCLMULQDQ also have SSE4.1.) +// +// If \vl == 32 && \avx_level == 2, the generated code requires: +// VPCLMULQDQ && AVX2. +// +// If \vl == 32 && \avx_level == 10, the generated code requires: +// VPCLMULQDQ && (AVX10/256 || (AVX512BW && AVX512VL)) +// +// If \vl == 64 && \avx_level == 10, the generated code requires: +// VPCLMULQDQ && (AVX10/512 || (AVX512BW && AVX512VL)) +// +// Other \vl and \avx_level combinations are either not supported or not useful. +.macro _crc_pclmul n, lsb_crc, vl, avx_level + .set LSB_CRC, \lsb_crc + .set VL, \vl + .set AVX_LEVEL, \avx_level + + // Define aliases for the xmm, ymm, or zmm registers according to VL. +.irp i, 0,1,2,3,4,5,6,7 + .if VL == 16 + .set V\i, %xmm\i + .set LOG2_VL, 4 + .elseif VL == 32 + .set V\i, %ymm\i + .set LOG2_VL, 5 + .elseif VL == 64 + .set V\i, %zmm\i + .set LOG2_VL, 6 + .else + .error "Unsupported vector length" + .endif +.endr + // Define aliases for the function parameters. + // Note: when crc_t is shorter than u32, zero-extension to 32 bits is + // guaranteed by the ABI. Zero-extension to 64 bits is *not* guaranteed + // when crc_t is shorter than u64. +#ifdef __x86_64__ +.if \n <= 32 + .set CRC, %edi +.else + .set CRC, %rdi +.endif + .set BUF, %rsi + .set LEN, %rdx + .set LEN32, %edx + .set LEN8, %dl + .set CONSTS_PTR, %rcx +#else + // 32-bit support, assuming -mregparm=3 and not including support for + // CRC-64 (which would use both eax and edx to pass the crc parameter). + .set CRC, %eax + .set BUF, %edx + .set LEN, %ecx + .set LEN32, %ecx + .set LEN8, %cl + .set CONSTS_PTR, %ebx // Passed on stack +#endif + + // Define aliases for some local variables. V0-V5 are used without + // aliases (for accumulators, data, temporary values, etc). Staying + // within the first 8 vector registers keeps the code 32-bit SSE + // compatible and reduces the size of 64-bit SSE code slightly. + .set BSWAP_MASK, V6 + .set BSWAP_MASK_YMM, %ymm6 + .set BSWAP_MASK_XMM, %xmm6 + .set CONSTS, V7 + .set CONSTS_YMM, %ymm7 + .set CONSTS_XMM, %xmm7 + +#ifdef __i386__ + push CONSTS_PTR + mov 8(%esp), CONSTS_PTR +#endif + + // Create a 128-bit vector that contains the initial CRC in the end + // representing the high-order polynomial coefficients, and the rest 0. + // If the CRC is msb-first, also load the byte-reflection table. +.if \n <= 32 + _cond_vex movd, CRC, %xmm0 +.else + _cond_vex movq, CRC, %xmm0 +.endif +.if !LSB_CRC + _cond_vex pslldq, $(128-\n)/8, %xmm0, %xmm0 + _vbroadcast OFFSETOF_BSWAP_MASK(CONSTS_PTR), BSWAP_MASK +.endif + + // Load the first vector of data and XOR the initial CRC into the + // appropriate end of the first 128-bit lane of data. If LEN < VL, then + // use a short vector and jump ahead to the final reduction. (LEN >= 16 + // is guaranteed here but not necessarily LEN >= VL.) +.if VL >= 32 + cmp $VL, LEN + jae .Lat_least_1vec\@ + .if VL == 64 + cmp $32, LEN32 + jb .Lless_than_32bytes\@ + _prepare_v0 32, %ymm0, %ymm1, BSWAP_MASK_YMM + add $32, BUF + jmp .Lreduce_256bits_to_128bits\@ +.Lless_than_32bytes\@: + .endif + _prepare_v0 16, %xmm0, %xmm1, BSWAP_MASK_XMM + add $16, BUF + vmovdqa OFFSETOF_FOLD_ACROSS_128_BITS_CONSTS(CONSTS_PTR), CONSTS_XMM + jmp .Lcheck_for_partial_block\@ +.Lat_least_1vec\@: +.endif + _prepare_v0 VL, V0, V1, BSWAP_MASK + + // Handle VL <= LEN < 4*VL. + cmp $4*VL-1, LEN + ja .Lat_least_4vecs\@ + add $VL, BUF + // If VL <= LEN < 2*VL, then jump ahead to the reduction from 1 vector. + // If VL==16 then load fold_across_128_bits_consts first, as the final + // reduction depends on it and it won't be loaded anywhere else. + cmp $2*VL-1, LEN32 +.if VL == 16 + _cond_vex movdqa, OFFSETOF_FOLD_ACROSS_128_BITS_CONSTS(CONSTS_PTR), CONSTS_XMM +.endif + jbe .Lreduce_1vec_to_128bits\@ + // Otherwise 2*VL <= LEN < 4*VL. Load one more vector and jump ahead to + // the reduction from 2 vectors. + _load_data VL, (BUF), BSWAP_MASK, V1 + add $VL, BUF + jmp .Lreduce_2vecs_to_1\@ + +.Lat_least_4vecs\@: + // Load 3 more vectors of data. + _load_data VL, 1*VL(BUF), BSWAP_MASK, V1 + _load_data VL, 2*VL(BUF), BSWAP_MASK, V2 + _load_data VL, 3*VL(BUF), BSWAP_MASK, V3 + sub $-4*VL, BUF // Shorter than 'add 4*VL' when VL=32 + add $-4*VL, LEN // Shorter than 'sub 4*VL' when VL=32 + + // Main loop: while LEN >= 4*VL, fold the 4 vectors V0-V3 into the next + // 4 vectors of data and write the result back to V0-V3. + cmp $4*VL-1, LEN // Shorter than 'cmp 4*VL' when VL=32 + jbe .Lreduce_4vecs_to_2\@ + _load_vec_folding_consts 2 +.Lfold_4vecs_loop\@: + _fold_vec_mem VL, V0, 0*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + _fold_vec_mem VL, V1, 1*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + _fold_vec_mem VL, V2, 2*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + _fold_vec_mem VL, V3, 3*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + sub $-4*VL, BUF + add $-4*VL, LEN + cmp $4*VL-1, LEN + ja .Lfold_4vecs_loop\@ + + // Fold V0,V1 into V2,V3 and write the result back to V0,V1. Then fold + // two more vectors of data from BUF, if at least that much remains. +.Lreduce_4vecs_to_2\@: + _load_vec_folding_consts 1 + _fold_vec V0, V2, CONSTS, V4 + _fold_vec V1, V3, CONSTS, V4 + test $2*VL, LEN8 + jz .Lreduce_2vecs_to_1\@ + _fold_vec_mem VL, V0, 0*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + _fold_vec_mem VL, V1, 1*VL(BUF), CONSTS, BSWAP_MASK, V4, V5 + sub $-2*VL, BUF + + // Fold V0 into V1 and write the result back to V0. Then fold one more + // vector of data from BUF, if at least that much remains. +.Lreduce_2vecs_to_1\@: + _load_vec_folding_consts 0 + _fold_vec_final VL, V0, V1, CONSTS, BSWAP_MASK, V4, V5 + +.Lreduce_1vec_to_128bits\@: +.if VL == 64 + // Reduce 512-bit %zmm0 to 256-bit %ymm0. Then fold 256 more bits of + // data from BUF, if at least that much remains. + vbroadcasti128 OFFSETOF_FOLD_ACROSS_256_BITS_CONSTS(CONSTS_PTR), CONSTS_YMM + vextracti64x4 $1, %zmm0, %ymm1 + _fold_vec_final 32, %ymm0, %ymm1, CONSTS_YMM, BSWAP_MASK_YMM, %ymm4, %ymm5 +.Lreduce_256bits_to_128bits\@: +.endif +.if VL >= 32 + // Reduce 256-bit %ymm0 to 128-bit %xmm0. Then fold 128 more bits of + // data from BUF, if at least that much remains. + vmovdqa OFFSETOF_FOLD_ACROSS_128_BITS_CONSTS(CONSTS_PTR), CONSTS_XMM + vextracti128 $1, %ymm0, %xmm1 + _fold_vec_final 16, %xmm0, %xmm1, CONSTS_XMM, BSWAP_MASK_XMM, %xmm4, %xmm5 +.Lcheck_for_partial_block\@: +.endif + and $15, LEN32 + jz .Lreduce_128bits_to_crc\@ + + // 1 <= LEN <= 15 data bytes remain in BUF. The polynomial is now + // A*(x^(8*LEN)) + B, where A is the 128-bit polynomial stored in %xmm0 + // and B is the polynomial of the remaining LEN data bytes. To reduce + // this to 128 bits without needing fold constants for each possible + // LEN, rearrange this expression into C1*(x^128) + C2, where + // C1 = floor(A / x^(128 - 8*LEN)) and C2 = A*x^(8*LEN) + B mod x^128. + // Then fold C1 into C2, which is just another fold across 128 bits. + +.if !LSB_CRC || AVX_LEVEL == 0 + // Load the last 16 data bytes. Note that originally LEN was >= 16. + _load_data 16, "-16(BUF,LEN)", BSWAP_MASK_XMM, %xmm2 +.endif // Else will use vpblendvb mem operand later. +.if !LSB_CRC + neg LEN // Needed for indexing shuf_table +.endif + + // tmp = A*x^(8*LEN) mod x^128 + // lsb: pshufb by [LEN, LEN+1, ..., 15, -1, -1, ..., -1] + // i.e. right-shift by LEN bytes. + // msb: pshufb by [-1, -1, ..., -1, 0, 1, ..., 15-LEN] + // i.e. left-shift by LEN bytes. + _cond_vex movdqu, "OFFSETOF_SHUF_TABLE+16(CONSTS_PTR,LEN)", %xmm3 + _cond_vex pshufb, %xmm3, %xmm0, %xmm1 + + // C1 = floor(A / x^(128 - 8*LEN)) + // lsb: pshufb by [-1, -1, ..., -1, 0, 1, ..., LEN-1] + // i.e. left-shift by 16-LEN bytes. + // msb: pshufb by [16-LEN, 16-LEN+1, ..., 15, -1, -1, ..., -1] + // i.e. right-shift by 16-LEN bytes. + _cond_vex pshufb, "OFFSETOF_SHUF_TABLE+32*!LSB_CRC(CONSTS_PTR,LEN)", \ + %xmm0, %xmm0, unaligned_mem_tmp=%xmm4 + + // C2 = tmp + B. This is just a blend of tmp with the last 16 data + // bytes (reflected if msb-first). The blend mask is the shuffle table + // that was used to create tmp. 0 selects tmp, and 1 last16databytes. +.if AVX_LEVEL == 0 + movdqa %xmm0, %xmm4 + movdqa %xmm3, %xmm0 + pblendvb %xmm2, %xmm1 // uses %xmm0 as implicit operand + movdqa %xmm4, %xmm0 +.elseif LSB_CRC + vpblendvb %xmm3, -16(BUF,LEN), %xmm1, %xmm1 +.else + vpblendvb %xmm3, %xmm2, %xmm1, %xmm1 +.endif + + // Fold C1 into C2 and store the 128-bit result in %xmm0. + _fold_vec %xmm0, %xmm1, CONSTS_XMM, %xmm4 + +.Lreduce_128bits_to_crc\@: + // Compute the CRC as %xmm0 * x^n mod G. Here %xmm0 means the 128-bit + // polynomial stored in %xmm0 (using either lsb-first or msb-first bit + // order according to LSB_CRC), and G is the CRC's generator polynomial. + + // First, multiply %xmm0 by x^n and reduce the result to 64+n bits: + // + // t0 := (x^(64+n) mod G) * floor(%xmm0 / x^64) + + // x^n * (%xmm0 mod x^64) + // + // Store t0 * x^(64-n) in %xmm0. I.e., actually do: + // + // %xmm0 := ((x^(64+n) mod G) * x^(64-n)) * floor(%xmm0 / x^64) + + // x^64 * (%xmm0 mod x^64) + // + // The extra unreduced factor of x^(64-n) makes floor(t0 / x^n) aligned + // to the HI64_TERMS of %xmm0 so that the next pclmulqdq can easily + // select it. The 64-bit constant (x^(64+n) mod G) * x^(64-n) in the + // msb-first case, or (x^(63+n) mod G) * x^(64-n) in the lsb-first case + // (considering the extra factor of x that gets implicitly introduced by + // each pclmulqdq when using lsb-first order), is identical to the + // constant that was used earlier for folding the LO64_TERMS across 128 + // bits. Thus it's already available in LO64_TERMS of CONSTS_XMM. + _pclmulqdq CONSTS_XMM, LO64_TERMS, %xmm0, HI64_TERMS, %xmm1 +.if LSB_CRC + _cond_vex psrldq, $8, %xmm0, %xmm0 // x^64 * (%xmm0 mod x^64) +.else + _cond_vex pslldq, $8, %xmm0, %xmm0 // x^64 * (%xmm0 mod x^64) +.endif + _cond_vex pxor, %xmm1, %xmm0, %xmm0 + // The HI64_TERMS of %xmm0 now contain floor(t0 / x^n). + // The LO64_TERMS of %xmm0 now contain (t0 mod x^n) * x^(64-n). + + // First step of Barrett reduction: Compute floor(t0 / G). This is the + // polynomial by which G needs to be multiplied to cancel out the x^n + // and higher terms of t0, i.e. to reduce t0 mod G. First do: + // + // t1 := floor(x^(63+n) / G) * x * floor(t0 / x^n) + // + // Then the desired value floor(t0 / G) is floor(t1 / x^64). The 63 in + // x^(63+n) is the maximum degree of floor(t0 / x^n) and thus the lowest + // value that makes enough precision be carried through the calculation. + // + // The '* x' makes it so the result is floor(t1 / x^64) rather than + // floor(t1 / x^63), making it qword-aligned in HI64_TERMS so that it + // can be extracted much more easily in the next step. In the lsb-first + // case the '* x' happens implicitly. In the msb-first case it must be + // done explicitly; floor(x^(63+n) / G) * x is a 65-bit constant, so the + // constant passed to pclmulqdq is (floor(x^(63+n) / G) * x) - x^64, and + // the multiplication by the x^64 term is handled using a pxor. The + // pxor causes the low 64 terms of t1 to be wrong, but they are unused. + _cond_vex movdqa, OFFSETOF_BARRETT_REDUCTION_CONSTS(CONSTS_PTR), CONSTS_XMM + _pclmulqdq CONSTS_XMM, HI64_TERMS, %xmm0, HI64_TERMS, %xmm1 +.if !LSB_CRC + _cond_vex pxor, %xmm0, %xmm1, %xmm1 // += x^64 * floor(t0 / x^n) +.endif + // The HI64_TERMS of %xmm1 now contain floor(t1 / x^64) = floor(t0 / G). + + // Second step of Barrett reduction: Cancel out the x^n and higher terms + // of t0 by subtracting the needed multiple of G. This gives the CRC: + // + // crc := t0 - (G * floor(t0 / G)) + // + // But %xmm0 contains t0 * x^(64-n), so it's more convenient to do: + // + // crc := ((t0 * x^(64-n)) - ((G * x^(64-n)) * floor(t0 / G))) / x^(64-n) + // + // Furthermore, since the resulting CRC is n-bit, if mod x^n is + // explicitly applied to it then the x^n term of G makes no difference + // in the result and can be omitted. This helps keep the constant + // multiplier in 64 bits in most cases. This gives the following: + // + // %xmm0 := %xmm0 - (((G - x^n) * x^(64-n)) * floor(t0 / G)) + // crc := (%xmm0 / x^(64-n)) mod x^n + // + // In the lsb-first case, each pclmulqdq implicitly introduces + // an extra factor of x, so in that case the constant that needs to be + // passed to pclmulqdq is actually '(G - x^n) * x^(63-n)' when n <= 63. + // For lsb-first CRCs where n=64, the extra factor of x cannot be as + // easily avoided. In that case, instead pass '(G - x^n - x^0) / x' to + // pclmulqdq and handle the x^0 term (i.e. 1) separately. (All CRC + // polynomials have nonzero x^n and x^0 terms.) It works out as: the + // CRC has be XORed with the physically low qword of %xmm1, representing + // floor(t0 / G). The most efficient way to do that is to move it to + // the physically high qword and use a ternlog to combine the two XORs. +.if LSB_CRC && \n == 64 + _cond_vex punpcklqdq, %xmm1, %xmm2, %xmm2 + _pclmulqdq CONSTS_XMM, LO64_TERMS, %xmm1, HI64_TERMS, %xmm1 + .if AVX_LEVEL < 10 + _cond_vex pxor, %xmm2, %xmm0, %xmm0 + _cond_vex pxor, %xmm1, %xmm0, %xmm0 + .else + vpternlogq $0x96, %xmm2, %xmm1, %xmm0 + .endif + _cond_vex "pextrq $1,", %xmm0, %rax // (%xmm0 / x^0) mod x^64 +.else + _pclmulqdq CONSTS_XMM, LO64_TERMS, %xmm1, HI64_TERMS, %xmm1 + _cond_vex pxor, %xmm1, %xmm0, %xmm0 + .if \n == 8 + _cond_vex "pextrb $7 + LSB_CRC,", %xmm0, %eax // (%xmm0 / x^56) mod x^8 + .elseif \n == 16 + _cond_vex "pextrw $3 + LSB_CRC,", %xmm0, %eax // (%xmm0 / x^48) mod x^16 + .elseif \n == 32 + _cond_vex "pextrd $1 + LSB_CRC,", %xmm0, %eax // (%xmm0 / x^32) mod x^32 + .else // \n == 64 && !LSB_CRC + _cond_vex movq, %xmm0, %rax // (%xmm0 / x^0) mod x^64 + .endif +.endif + +.if VL > 16 + vzeroupper // Needed when ymm or zmm registers may have been used. +.endif +#ifdef __i386__ + pop CONSTS_PTR +#endif + RET +.endm + +#ifdef CONFIG_AS_VPCLMULQDQ +#define DEFINE_CRC_PCLMUL_FUNCS(prefix, bits, lsb) \ +SYM_FUNC_START(prefix##_pclmul_sse); \ + _crc_pclmul n=bits, lsb_crc=lsb, vl=16, avx_level=0; \ +SYM_FUNC_END(prefix##_pclmul_sse); \ + \ +SYM_FUNC_START(prefix##_vpclmul_avx2); \ + _crc_pclmul n=bits, lsb_crc=lsb, vl=32, avx_level=2; \ +SYM_FUNC_END(prefix##_vpclmul_avx2); \ + \ +SYM_FUNC_START(prefix##_vpclmul_avx10_256); \ + _crc_pclmul n=bits, lsb_crc=lsb, vl=32, avx_level=10; \ +SYM_FUNC_END(prefix##_vpclmul_avx10_256); \ + \ +SYM_FUNC_START(prefix##_vpclmul_avx10_512); \ + _crc_pclmul n=bits, lsb_crc=lsb, vl=64, avx_level=10; \ +SYM_FUNC_END(prefix##_vpclmul_avx10_512); +#else +#define DEFINE_CRC_PCLMUL_FUNCS(prefix, bits, lsb) \ +SYM_FUNC_START(prefix##_pclmul_sse); \ + _crc_pclmul n=bits, lsb_crc=lsb, vl=16, avx_level=0; \ +SYM_FUNC_END(prefix##_pclmul_sse); +#endif // !CONFIG_AS_VPCLMULQDQ diff --git a/arch/x86/lib/crc-pclmul-template.h b/arch/x86/lib/crc-pclmul-template.h new file mode 100644 index 000000000000..7b89f0edbc17 --- /dev/null +++ b/arch/x86/lib/crc-pclmul-template.h @@ -0,0 +1,81 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Macros for accessing the [V]PCLMULQDQ-based CRC functions that are + * instantiated by crc-pclmul-template.S + * + * Copyright 2025 Google LLC + * + * Author: Eric Biggers + */ +#ifndef _CRC_PCLMUL_TEMPLATE_H +#define _CRC_PCLMUL_TEMPLATE_H + +#include +#include +#include +#include +#include "crc-pclmul-consts.h" + +#define DECLARE_CRC_PCLMUL_FUNCS(prefix, crc_t) \ +crc_t prefix##_pclmul_sse(crc_t crc, const u8 *p, size_t len, \ + const void *consts_ptr); \ +crc_t prefix##_vpclmul_avx2(crc_t crc, const u8 *p, size_t len, \ + const void *consts_ptr); \ +crc_t prefix##_vpclmul_avx10_256(crc_t crc, const u8 *p, size_t len, \ + const void *consts_ptr); \ +crc_t prefix##_vpclmul_avx10_512(crc_t crc, const u8 *p, size_t len, \ + const void *consts_ptr); \ +DEFINE_STATIC_CALL(prefix##_pclmul, prefix##_pclmul_sse) + +#define INIT_CRC_PCLMUL(prefix) \ +do { \ + if (IS_ENABLED(CONFIG_AS_VPCLMULQDQ) && \ + boot_cpu_has(X86_FEATURE_VPCLMULQDQ) && \ + boot_cpu_has(X86_FEATURE_AVX2) && \ + cpu_has_xfeatures(XFEATURE_MASK_YMM, NULL)) { \ + if (boot_cpu_has(X86_FEATURE_AVX512BW) && \ + boot_cpu_has(X86_FEATURE_AVX512VL) && \ + cpu_has_xfeatures(XFEATURE_MASK_AVX512, NULL)) { \ + if (boot_cpu_has(X86_FEATURE_PREFER_YMM)) \ + static_call_update(prefix##_pclmul, \ + prefix##_vpclmul_avx10_256); \ + else \ + static_call_update(prefix##_pclmul, \ + prefix##_vpclmul_avx10_512); \ + } else { \ + static_call_update(prefix##_pclmul, \ + prefix##_vpclmul_avx2); \ + } \ + } \ +} while (0) + +/* + * Call a [V]PCLMULQDQ optimized CRC function if the data length is at least 16 + * bytes, the CPU has PCLMULQDQ support, and the current context may use SIMD. + * + * 16 bytes is the minimum length supported by the [V]PCLMULQDQ functions. + * There is overhead associated with kernel_fpu_begin() and kernel_fpu_end(), + * varying by CPU and factors such as which parts of the "FPU" state userspace + * has touched, which could result in a larger cutoff being better. Indeed, a + * larger cutoff is usually better for a *single* message. However, the + * overhead of the FPU section gets amortized if multiple FPU sections get + * executed before returning to userspace, since the XSAVE and XRSTOR occur only + * once. Considering that and the fact that the [V]PCLMULQDQ code is lighter on + * the dcache than the table-based code is, a 16-byte cutoff seems to work well. + */ +#define CRC_PCLMUL(crc, p, len, prefix, consts, have_pclmulqdq) \ +do { \ + if ((len) >= 16 && static_branch_likely(&(have_pclmulqdq)) && \ + crypto_simd_usable()) { \ + const void *consts_ptr; \ + \ + consts_ptr = (consts).fold_across_128_bits_consts; \ + kernel_fpu_begin(); \ + crc = static_call(prefix##_pclmul)((crc), (p), (len), \ + consts_ptr); \ + kernel_fpu_end(); \ + return crc; \ + } \ +} while (0) + +#endif /* _CRC_PCLMUL_TEMPLATE_H */ diff --git a/arch/x86/lib/crc-t10dif-glue.c b/arch/x86/lib/crc-t10dif-glue.c index 13f07ddc9122..6b09374b8355 100644 --- a/arch/x86/lib/crc-t10dif-glue.c +++ b/arch/x86/lib/crc-t10dif-glue.c @@ -1,37 +1,32 @@ // SPDX-License-Identifier: GPL-2.0-or-later /* - * CRC-T10DIF using PCLMULQDQ instructions + * CRC-T10DIF using [V]PCLMULQDQ instructions * * Copyright 2024 Google LLC */ -#include -#include -#include #include #include +#include "crc-pclmul-template.h" static DEFINE_STATIC_KEY_FALSE(have_pclmulqdq); -asmlinkage u16 crc_t10dif_pcl(u16 init_crc, const u8 *buf, size_t len); +DECLARE_CRC_PCLMUL_FUNCS(crc16_msb, u16); u16 crc_t10dif_arch(u16 crc, const u8 *p, size_t len) { - if (len >= 16 && - static_key_enabled(&have_pclmulqdq) && crypto_simd_usable()) { - kernel_fpu_begin(); - crc = crc_t10dif_pcl(crc, p, len); - kernel_fpu_end(); - return crc; - } + CRC_PCLMUL(crc, p, len, crc16_msb, crc16_msb_0x8bb7_consts, + have_pclmulqdq); return crc_t10dif_generic(crc, p, len); } EXPORT_SYMBOL(crc_t10dif_arch); static int __init crc_t10dif_x86_init(void) { - if (boot_cpu_has(X86_FEATURE_PCLMULQDQ)) + if (boot_cpu_has(X86_FEATURE_PCLMULQDQ)) { static_branch_enable(&have_pclmulqdq); + INIT_CRC_PCLMUL(crc16_msb); + } return 0; } arch_initcall(crc_t10dif_x86_init); @@ -47,5 +42,5 @@ bool crc_t10dif_is_optimized(void) } EXPORT_SYMBOL(crc_t10dif_is_optimized); -MODULE_DESCRIPTION("CRC-T10DIF using PCLMULQDQ instructions"); +MODULE_DESCRIPTION("CRC-T10DIF using [V]PCLMULQDQ instructions"); MODULE_LICENSE("GPL"); diff --git a/arch/x86/lib/crc16-msb-pclmul.S b/arch/x86/lib/crc16-msb-pclmul.S new file mode 100644 index 000000000000..e9fe248093a8 --- /dev/null +++ b/arch/x86/lib/crc16-msb-pclmul.S @@ -0,0 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +// Copyright 2025 Google LLC + +#include "crc-pclmul-template.S" + +DEFINE_CRC_PCLMUL_FUNCS(crc16_msb, /* bits= */ 16, /* lsb= */ 0) diff --git a/arch/x86/lib/crc32-glue.c b/arch/x86/lib/crc32-glue.c index 2dd18a886ded..5b2878c2f793 100644 --- a/arch/x86/lib/crc32-glue.c +++ b/arch/x86/lib/crc32-glue.c @@ -7,43 +7,20 @@ * Copyright 2024 Google LLC */ -#include -#include -#include #include -#include #include - -/* minimum size of buffer for crc32_pclmul_le_16 */ -#define CRC32_PCLMUL_MIN_LEN 64 +#include "crc-pclmul-template.h" static DEFINE_STATIC_KEY_FALSE(have_crc32); static DEFINE_STATIC_KEY_FALSE(have_pclmulqdq); -u32 crc32_pclmul_le_16(u32 crc, const u8 *buffer, size_t len); +DECLARE_CRC_PCLMUL_FUNCS(crc32_lsb, u32); u32 crc32_le_arch(u32 crc, const u8 *p, size_t len) { - if (len >= CRC32_PCLMUL_MIN_LEN + 15 && - static_branch_likely(&have_pclmulqdq) && crypto_simd_usable()) { - size_t n = -(uintptr_t)p & 15; - - /* align p to 16-byte boundary */ - if (n) { - crc = crc32_le_base(crc, p, n); - p += n; - len -= n; - } - n = round_down(len, 16); - kernel_fpu_begin(); - crc = crc32_pclmul_le_16(crc, p, n); - kernel_fpu_end(); - p += n; - len -= n; - } - if (len) - crc = crc32_le_base(crc, p, len); - return crc; + CRC_PCLMUL(crc, p, len, crc32_lsb, crc32_lsb_0xedb88320_consts, + have_pclmulqdq); + return crc32_le_base(crc, p, len); } EXPORT_SYMBOL(crc32_le_arch); @@ -78,10 +55,18 @@ u32 crc32c_le_arch(u32 crc, const u8 *p, size_t len) for (num_longs = len / sizeof(unsigned long); num_longs != 0; num_longs--, p += sizeof(unsigned long)) - asm(CRC32_INST : "+r" (crc) : "rm" (*(unsigned long *)p)); + asm(CRC32_INST : "+r" (crc) : ASM_INPUT_RM (*(unsigned long *)p)); - for (len %= sizeof(unsigned long); len; len--, p++) - asm("crc32b %1, %0" : "+r" (crc) : "rm" (*p)); + if (sizeof(unsigned long) > 4 && (len & 4)) { + asm("crc32l %1, %0" : "+r" (crc) : ASM_INPUT_RM (*(u32 *)p)); + p += 4; + } + if (len & 2) { + asm("crc32w %1, %0" : "+r" (crc) : ASM_INPUT_RM (*(u16 *)p)); + p += 2; + } + if (len & 1) + asm("crc32b %1, %0" : "+r" (crc) : ASM_INPUT_RM (*p)); return crc; } @@ -97,8 +82,10 @@ static int __init crc32_x86_init(void) { if (boot_cpu_has(X86_FEATURE_XMM4_2)) static_branch_enable(&have_crc32); - if (boot_cpu_has(X86_FEATURE_PCLMULQDQ)) + if (boot_cpu_has(X86_FEATURE_PCLMULQDQ)) { static_branch_enable(&have_pclmulqdq); + INIT_CRC_PCLMUL(crc32_lsb); + } return 0; } arch_initcall(crc32_x86_init); diff --git a/arch/x86/lib/crc32-pclmul.S b/arch/x86/lib/crc32-pclmul.S index f9637789cac1..f20f40fb0172 100644 --- a/arch/x86/lib/crc32-pclmul.S +++ b/arch/x86/lib/crc32-pclmul.S @@ -1,217 +1,6 @@ -/* SPDX-License-Identifier: GPL-2.0-only */ -/* - * Copyright 2012 Xyratex Technology Limited - * - * Using hardware provided PCLMULQDQ instruction to accelerate the CRC32 - * calculation. - * CRC32 polynomial:0x04c11db7(BE)/0xEDB88320(LE) - * PCLMULQDQ is a new instruction in Intel SSE4.2, the reference can be found - * at: - * http://www.intel.com/products/processor/manuals/ - * Intel(R) 64 and IA-32 Architectures Software Developer's Manual - * Volume 2B: Instruction Set Reference, N-Z - * - * Authors: Gregory Prestas - * Alexander Boyko - */ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +// Copyright 2025 Google LLC -#include +#include "crc-pclmul-template.S" - -.section .rodata -.align 16 -/* - * [x4*128+32 mod P(x) << 32)]' << 1 = 0x154442bd4 - * #define CONSTANT_R1 0x154442bd4LL - * - * [(x4*128-32 mod P(x) << 32)]' << 1 = 0x1c6e41596 - * #define CONSTANT_R2 0x1c6e41596LL - */ -.Lconstant_R2R1: - .octa 0x00000001c6e415960000000154442bd4 -/* - * [(x128+32 mod P(x) << 32)]' << 1 = 0x1751997d0 - * #define CONSTANT_R3 0x1751997d0LL - * - * [(x128-32 mod P(x) << 32)]' << 1 = 0x0ccaa009e - * #define CONSTANT_R4 0x0ccaa009eLL - */ -.Lconstant_R4R3: - .octa 0x00000000ccaa009e00000001751997d0 -/* - * [(x64 mod P(x) << 32)]' << 1 = 0x163cd6124 - * #define CONSTANT_R5 0x163cd6124LL - */ -.Lconstant_R5: - .octa 0x00000000000000000000000163cd6124 -.Lconstant_mask32: - .octa 0x000000000000000000000000FFFFFFFF -/* - * #define CRCPOLY_TRUE_LE_FULL 0x1DB710641LL - * - * Barrett Reduction constant (u64`) = u` = (x**64 / P(x))` = 0x1F7011641LL - * #define CONSTANT_RU 0x1F7011641LL - */ -.Lconstant_RUpoly: - .octa 0x00000001F701164100000001DB710641 - -#define CONSTANT %xmm0 - -#ifdef __x86_64__ -#define CRC %edi -#define BUF %rsi -#define LEN %rdx -#else -#define CRC %eax -#define BUF %edx -#define LEN %ecx -#endif - - - -.text -/** - * Calculate crc32 - * CRC - initial crc32 - * BUF - buffer (16 bytes aligned) - * LEN - sizeof buffer (16 bytes aligned), LEN should be greater than 63 - * return %eax crc32 - * u32 crc32_pclmul_le_16(u32 crc, const u8 *buffer, size_t len); - */ - -SYM_FUNC_START(crc32_pclmul_le_16) /* buffer and buffer size are 16 bytes aligned */ - movdqa (BUF), %xmm1 - movdqa 0x10(BUF), %xmm2 - movdqa 0x20(BUF), %xmm3 - movdqa 0x30(BUF), %xmm4 - movd CRC, CONSTANT - pxor CONSTANT, %xmm1 - sub $0x40, LEN - add $0x40, BUF - cmp $0x40, LEN - jb .Lless_64 - -#ifdef __x86_64__ - movdqa .Lconstant_R2R1(%rip), CONSTANT -#else - movdqa .Lconstant_R2R1, CONSTANT -#endif - -.Lloop_64:/* 64 bytes Full cache line folding */ - prefetchnta 0x40(BUF) - movdqa %xmm1, %xmm5 - movdqa %xmm2, %xmm6 - movdqa %xmm3, %xmm7 -#ifdef __x86_64__ - movdqa %xmm4, %xmm8 -#endif - pclmulqdq $0x00, CONSTANT, %xmm1 - pclmulqdq $0x00, CONSTANT, %xmm2 - pclmulqdq $0x00, CONSTANT, %xmm3 -#ifdef __x86_64__ - pclmulqdq $0x00, CONSTANT, %xmm4 -#endif - pclmulqdq $0x11, CONSTANT, %xmm5 - pclmulqdq $0x11, CONSTANT, %xmm6 - pclmulqdq $0x11, CONSTANT, %xmm7 -#ifdef __x86_64__ - pclmulqdq $0x11, CONSTANT, %xmm8 -#endif - pxor %xmm5, %xmm1 - pxor %xmm6, %xmm2 - pxor %xmm7, %xmm3 -#ifdef __x86_64__ - pxor %xmm8, %xmm4 -#else - /* xmm8 unsupported for x32 */ - movdqa %xmm4, %xmm5 - pclmulqdq $0x00, CONSTANT, %xmm4 - pclmulqdq $0x11, CONSTANT, %xmm5 - pxor %xmm5, %xmm4 -#endif - - pxor (BUF), %xmm1 - pxor 0x10(BUF), %xmm2 - pxor 0x20(BUF), %xmm3 - pxor 0x30(BUF), %xmm4 - - sub $0x40, LEN - add $0x40, BUF - cmp $0x40, LEN - jge .Lloop_64 -.Lless_64:/* Folding cache line into 128bit */ -#ifdef __x86_64__ - movdqa .Lconstant_R4R3(%rip), CONSTANT -#else - movdqa .Lconstant_R4R3, CONSTANT -#endif - prefetchnta (BUF) - - movdqa %xmm1, %xmm5 - pclmulqdq $0x00, CONSTANT, %xmm1 - pclmulqdq $0x11, CONSTANT, %xmm5 - pxor %xmm5, %xmm1 - pxor %xmm2, %xmm1 - - movdqa %xmm1, %xmm5 - pclmulqdq $0x00, CONSTANT, %xmm1 - pclmulqdq $0x11, CONSTANT, %xmm5 - pxor %xmm5, %xmm1 - pxor %xmm3, %xmm1 - - movdqa %xmm1, %xmm5 - pclmulqdq $0x00, CONSTANT, %xmm1 - pclmulqdq $0x11, CONSTANT, %xmm5 - pxor %xmm5, %xmm1 - pxor %xmm4, %xmm1 - - cmp $0x10, LEN - jb .Lfold_64 -.Lloop_16:/* Folding rest buffer into 128bit */ - movdqa %xmm1, %xmm5 - pclmulqdq $0x00, CONSTANT, %xmm1 - pclmulqdq $0x11, CONSTANT, %xmm5 - pxor %xmm5, %xmm1 - pxor (BUF), %xmm1 - sub $0x10, LEN - add $0x10, BUF - cmp $0x10, LEN - jge .Lloop_16 - -.Lfold_64: - /* perform the last 64 bit fold, also adds 32 zeroes - * to the input stream */ - pclmulqdq $0x01, %xmm1, CONSTANT /* R4 * xmm1.low */ - psrldq $0x08, %xmm1 - pxor CONSTANT, %xmm1 - - /* final 32-bit fold */ - movdqa %xmm1, %xmm2 -#ifdef __x86_64__ - movdqa .Lconstant_R5(%rip), CONSTANT - movdqa .Lconstant_mask32(%rip), %xmm3 -#else - movdqa .Lconstant_R5, CONSTANT - movdqa .Lconstant_mask32, %xmm3 -#endif - psrldq $0x04, %xmm2 - pand %xmm3, %xmm1 - pclmulqdq $0x00, CONSTANT, %xmm1 - pxor %xmm2, %xmm1 - - /* Finish up with the bit-reversed barrett reduction 64 ==> 32 bits */ -#ifdef __x86_64__ - movdqa .Lconstant_RUpoly(%rip), CONSTANT -#else - movdqa .Lconstant_RUpoly, CONSTANT -#endif - movdqa %xmm1, %xmm2 - pand %xmm3, %xmm1 - pclmulqdq $0x10, CONSTANT, %xmm1 - pand %xmm3, %xmm1 - pclmulqdq $0x00, CONSTANT, %xmm1 - pxor %xmm2, %xmm1 - pextrd $0x01, %xmm1, %eax - - RET -SYM_FUNC_END(crc32_pclmul_le_16) +DEFINE_CRC_PCLMUL_FUNCS(crc32_lsb, /* bits= */ 32, /* lsb= */ 1) diff --git a/arch/x86/lib/crct10dif-pcl-asm_64.S b/arch/x86/lib/crct10dif-pcl-asm_64.S deleted file mode 100644 index 5286db5b8165..000000000000 --- a/arch/x86/lib/crct10dif-pcl-asm_64.S +++ /dev/null @@ -1,332 +0,0 @@ -######################################################################## -# Implement fast CRC-T10DIF computation with SSE and PCLMULQDQ instructions -# -# Copyright (c) 2013, Intel Corporation -# -# Authors: -# Erdinc Ozturk -# Vinodh Gopal -# James Guilford -# Tim Chen -# -# This software is available to you under a choice of one of two -# licenses. You may choose to be licensed under the terms of the GNU -# General Public License (GPL) Version 2, available from the file -# COPYING in the main directory of this source tree, or the -# OpenIB.org BSD license below: -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the -# distribution. -# -# * Neither the name of the Intel Corporation nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY INTEL CORPORATION ""AS IS"" AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL CORPORATION OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# Reference paper titled "Fast CRC Computation for Generic -# Polynomials Using PCLMULQDQ Instruction" -# URL: http://www.intel.com/content/dam/www/public/us/en/documents -# /white-papers/fast-crc-computation-generic-polynomials-pclmulqdq-paper.pdf -# - -#include - -.text - -#define init_crc %edi -#define buf %rsi -#define len %rdx - -#define FOLD_CONSTS %xmm10 -#define BSWAP_MASK %xmm11 - -# Fold reg1, reg2 into the next 32 data bytes, storing the result back into -# reg1, reg2. -.macro fold_32_bytes offset, reg1, reg2 - movdqu \offset(buf), %xmm9 - movdqu \offset+16(buf), %xmm12 - pshufb BSWAP_MASK, %xmm9 - pshufb BSWAP_MASK, %xmm12 - movdqa \reg1, %xmm8 - movdqa \reg2, %xmm13 - pclmulqdq $0x00, FOLD_CONSTS, \reg1 - pclmulqdq $0x11, FOLD_CONSTS, %xmm8 - pclmulqdq $0x00, FOLD_CONSTS, \reg2 - pclmulqdq $0x11, FOLD_CONSTS, %xmm13 - pxor %xmm9 , \reg1 - xorps %xmm8 , \reg1 - pxor %xmm12, \reg2 - xorps %xmm13, \reg2 -.endm - -# Fold src_reg into dst_reg. -.macro fold_16_bytes src_reg, dst_reg - movdqa \src_reg, %xmm8 - pclmulqdq $0x11, FOLD_CONSTS, \src_reg - pclmulqdq $0x00, FOLD_CONSTS, %xmm8 - pxor %xmm8, \dst_reg - xorps \src_reg, \dst_reg -.endm - -# -# u16 crc_t10dif_pcl(u16 init_crc, const *u8 buf, size_t len); -# -# Assumes len >= 16. -# -SYM_FUNC_START(crc_t10dif_pcl) - - movdqa .Lbswap_mask(%rip), BSWAP_MASK - - # For sizes less than 256 bytes, we can't fold 128 bytes at a time. - cmp $256, len - jl .Lless_than_256_bytes - - # Load the first 128 data bytes. Byte swapping is necessary to make the - # bit order match the polynomial coefficient order. - movdqu 16*0(buf), %xmm0 - movdqu 16*1(buf), %xmm1 - movdqu 16*2(buf), %xmm2 - movdqu 16*3(buf), %xmm3 - movdqu 16*4(buf), %xmm4 - movdqu 16*5(buf), %xmm5 - movdqu 16*6(buf), %xmm6 - movdqu 16*7(buf), %xmm7 - add $128, buf - pshufb BSWAP_MASK, %xmm0 - pshufb BSWAP_MASK, %xmm1 - pshufb BSWAP_MASK, %xmm2 - pshufb BSWAP_MASK, %xmm3 - pshufb BSWAP_MASK, %xmm4 - pshufb BSWAP_MASK, %xmm5 - pshufb BSWAP_MASK, %xmm6 - pshufb BSWAP_MASK, %xmm7 - - # XOR the first 16 data *bits* with the initial CRC value. - pxor %xmm8, %xmm8 - pinsrw $7, init_crc, %xmm8 - pxor %xmm8, %xmm0 - - movdqa .Lfold_across_128_bytes_consts(%rip), FOLD_CONSTS - - # Subtract 128 for the 128 data bytes just consumed. Subtract another - # 128 to simplify the termination condition of the following loop. - sub $256, len - - # While >= 128 data bytes remain (not counting xmm0-7), fold the 128 - # bytes xmm0-7 into them, storing the result back into xmm0-7. -.Lfold_128_bytes_loop: - fold_32_bytes 0, %xmm0, %xmm1 - fold_32_bytes 32, %xmm2, %xmm3 - fold_32_bytes 64, %xmm4, %xmm5 - fold_32_bytes 96, %xmm6, %xmm7 - add $128, buf - sub $128, len - jge .Lfold_128_bytes_loop - - # Now fold the 112 bytes in xmm0-xmm6 into the 16 bytes in xmm7. - - # Fold across 64 bytes. - movdqa .Lfold_across_64_bytes_consts(%rip), FOLD_CONSTS - fold_16_bytes %xmm0, %xmm4 - fold_16_bytes %xmm1, %xmm5 - fold_16_bytes %xmm2, %xmm6 - fold_16_bytes %xmm3, %xmm7 - # Fold across 32 bytes. - movdqa .Lfold_across_32_bytes_consts(%rip), FOLD_CONSTS - fold_16_bytes %xmm4, %xmm6 - fold_16_bytes %xmm5, %xmm7 - # Fold across 16 bytes. - movdqa .Lfold_across_16_bytes_consts(%rip), FOLD_CONSTS - fold_16_bytes %xmm6, %xmm7 - - # Add 128 to get the correct number of data bytes remaining in 0...127 - # (not counting xmm7), following the previous extra subtraction by 128. - # Then subtract 16 to simplify the termination condition of the - # following loop. - add $128-16, len - - # While >= 16 data bytes remain (not counting xmm7), fold the 16 bytes - # xmm7 into them, storing the result back into xmm7. - jl .Lfold_16_bytes_loop_done -.Lfold_16_bytes_loop: - movdqa %xmm7, %xmm8 - pclmulqdq $0x11, FOLD_CONSTS, %xmm7 - pclmulqdq $0x00, FOLD_CONSTS, %xmm8 - pxor %xmm8, %xmm7 - movdqu (buf), %xmm0 - pshufb BSWAP_MASK, %xmm0 - pxor %xmm0 , %xmm7 - add $16, buf - sub $16, len - jge .Lfold_16_bytes_loop - -.Lfold_16_bytes_loop_done: - # Add 16 to get the correct number of data bytes remaining in 0...15 - # (not counting xmm7), following the previous extra subtraction by 16. - add $16, len - je .Lreduce_final_16_bytes - -.Lhandle_partial_segment: - # Reduce the last '16 + len' bytes where 1 <= len <= 15 and the first 16 - # bytes are in xmm7 and the rest are the remaining data in 'buf'. To do - # this without needing a fold constant for each possible 'len', redivide - # the bytes into a first chunk of 'len' bytes and a second chunk of 16 - # bytes, then fold the first chunk into the second. - - movdqa %xmm7, %xmm2 - - # xmm1 = last 16 original data bytes - movdqu -16(buf, len), %xmm1 - pshufb BSWAP_MASK, %xmm1 - - # xmm2 = high order part of second chunk: xmm7 left-shifted by 'len' bytes. - lea .Lbyteshift_table+16(%rip), %rax - sub len, %rax - movdqu (%rax), %xmm0 - pshufb %xmm0, %xmm2 - - # xmm7 = first chunk: xmm7 right-shifted by '16-len' bytes. - pxor .Lmask1(%rip), %xmm0 - pshufb %xmm0, %xmm7 - - # xmm1 = second chunk: 'len' bytes from xmm1 (low-order bytes), - # then '16-len' bytes from xmm2 (high-order bytes). - pblendvb %xmm2, %xmm1 #xmm0 is implicit - - # Fold the first chunk into the second chunk, storing the result in xmm7. - movdqa %xmm7, %xmm8 - pclmulqdq $0x11, FOLD_CONSTS, %xmm7 - pclmulqdq $0x00, FOLD_CONSTS, %xmm8 - pxor %xmm8, %xmm7 - pxor %xmm1, %xmm7 - -.Lreduce_final_16_bytes: - # Reduce the 128-bit value M(x), stored in xmm7, to the final 16-bit CRC - - # Load 'x^48 * (x^48 mod G(x))' and 'x^48 * (x^80 mod G(x))'. - movdqa .Lfinal_fold_consts(%rip), FOLD_CONSTS - - # Fold the high 64 bits into the low 64 bits, while also multiplying by - # x^64. This produces a 128-bit value congruent to x^64 * M(x) and - # whose low 48 bits are 0. - movdqa %xmm7, %xmm0 - pclmulqdq $0x11, FOLD_CONSTS, %xmm7 # high bits * x^48 * (x^80 mod G(x)) - pslldq $8, %xmm0 - pxor %xmm0, %xmm7 # + low bits * x^64 - - # Fold the high 32 bits into the low 96 bits. This produces a 96-bit - # value congruent to x^64 * M(x) and whose low 48 bits are 0. - movdqa %xmm7, %xmm0 - pand .Lmask2(%rip), %xmm0 # zero high 32 bits - psrldq $12, %xmm7 # extract high 32 bits - pclmulqdq $0x00, FOLD_CONSTS, %xmm7 # high 32 bits * x^48 * (x^48 mod G(x)) - pxor %xmm0, %xmm7 # + low bits - - # Load G(x) and floor(x^48 / G(x)). - movdqa .Lbarrett_reduction_consts(%rip), FOLD_CONSTS - - # Use Barrett reduction to compute the final CRC value. - movdqa %xmm7, %xmm0 - pclmulqdq $0x11, FOLD_CONSTS, %xmm7 # high 32 bits * floor(x^48 / G(x)) - psrlq $32, %xmm7 # /= x^32 - pclmulqdq $0x00, FOLD_CONSTS, %xmm7 # *= G(x) - psrlq $48, %xmm0 - pxor %xmm7, %xmm0 # + low 16 nonzero bits - # Final CRC value (x^16 * M(x)) mod G(x) is in low 16 bits of xmm0. - - pextrw $0, %xmm0, %eax - RET - -.align 16 -.Lless_than_256_bytes: - # Checksumming a buffer of length 16...255 bytes - - # Load the first 16 data bytes. - movdqu (buf), %xmm7 - pshufb BSWAP_MASK, %xmm7 - add $16, buf - - # XOR the first 16 data *bits* with the initial CRC value. - pxor %xmm0, %xmm0 - pinsrw $7, init_crc, %xmm0 - pxor %xmm0, %xmm7 - - movdqa .Lfold_across_16_bytes_consts(%rip), FOLD_CONSTS - cmp $16, len - je .Lreduce_final_16_bytes # len == 16 - sub $32, len - jge .Lfold_16_bytes_loop # 32 <= len <= 255 - add $16, len - jmp .Lhandle_partial_segment # 17 <= len <= 31 -SYM_FUNC_END(crc_t10dif_pcl) - -.section .rodata, "a", @progbits -.align 16 - -# Fold constants precomputed from the polynomial 0x18bb7 -# G(x) = x^16 + x^15 + x^11 + x^9 + x^8 + x^7 + x^5 + x^4 + x^2 + x^1 + x^0 -.Lfold_across_128_bytes_consts: - .quad 0x0000000000006123 # x^(8*128) mod G(x) - .quad 0x0000000000002295 # x^(8*128+64) mod G(x) -.Lfold_across_64_bytes_consts: - .quad 0x0000000000001069 # x^(4*128) mod G(x) - .quad 0x000000000000dd31 # x^(4*128+64) mod G(x) -.Lfold_across_32_bytes_consts: - .quad 0x000000000000857d # x^(2*128) mod G(x) - .quad 0x0000000000007acc # x^(2*128+64) mod G(x) -.Lfold_across_16_bytes_consts: - .quad 0x000000000000a010 # x^(1*128) mod G(x) - .quad 0x0000000000001faa # x^(1*128+64) mod G(x) -.Lfinal_fold_consts: - .quad 0x1368000000000000 # x^48 * (x^48 mod G(x)) - .quad 0x2d56000000000000 # x^48 * (x^80 mod G(x)) -.Lbarrett_reduction_consts: - .quad 0x0000000000018bb7 # G(x) - .quad 0x00000001f65a57f8 # floor(x^48 / G(x)) - -.section .rodata.cst16.mask1, "aM", @progbits, 16 -.align 16 -.Lmask1: - .octa 0x80808080808080808080808080808080 - -.section .rodata.cst16.mask2, "aM", @progbits, 16 -.align 16 -.Lmask2: - .octa 0x00000000FFFFFFFFFFFFFFFFFFFFFFFF - -.section .rodata.cst16.bswap_mask, "aM", @progbits, 16 -.align 16 -.Lbswap_mask: - .octa 0x000102030405060708090A0B0C0D0E0F - -.section .rodata.cst32.byteshift_table, "aM", @progbits, 32 -.align 16 -# For 1 <= len <= 15, the 16-byte vector beginning at &byteshift_table[16 - len] -# is the index vector to shift left by 'len' bytes, and is also {0x80, ..., -# 0x80} XOR the index vector to shift right by '16 - len' bytes. -.Lbyteshift_table: - .byte 0x0, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87 - .byte 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f - .byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7 - .byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe , 0x0 diff --git a/drivers/nvme/host/Kconfig b/drivers/nvme/host/Kconfig index 486afe598184..0e1a554c596d 100644 --- a/drivers/nvme/host/Kconfig +++ b/drivers/nvme/host/Kconfig @@ -80,8 +80,7 @@ config NVME_TCP depends on INET depends on BLOCK select NVME_FABRICS - select CRYPTO - select CRYPTO_CRC32C + select CRC32 help This provides support for the NVMe over Fabrics protocol using the TCP transport. This allows you to use remote block devices diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c index 327f3f2f5399..21b56a378664 100644 --- a/drivers/nvme/host/tcp.c +++ b/drivers/nvme/host/tcp.c @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -17,7 +18,6 @@ #include #include #include -#include #include #include @@ -169,8 +169,8 @@ struct nvme_tcp_queue { bool hdr_digest; bool data_digest; bool tls_enabled; - struct ahash_request *rcv_hash; - struct ahash_request *snd_hash; + u32 rcv_crc; + u32 snd_crc; __le32 exp_ddgst; __le32 recv_ddgst; struct completion tls_complete; @@ -457,32 +457,29 @@ nvme_tcp_fetch_request(struct nvme_tcp_queue *queue) return req; } -static inline void nvme_tcp_ddgst_final(struct ahash_request *hash, - __le32 *dgst) +static inline void nvme_tcp_ddgst_init(u32 *crcp) { - ahash_request_set_crypt(hash, NULL, (u8 *)dgst, 0); - crypto_ahash_final(hash); + *crcp = ~0; } -static inline void nvme_tcp_ddgst_update(struct ahash_request *hash, +static inline void nvme_tcp_ddgst_update(u32 *crcp, struct page *page, off_t off, size_t len) { - struct scatterlist sg; + const void *virt = kmap_local_page(page + (off >> PAGE_SHIFT)); - sg_init_table(&sg, 1); - sg_set_page(&sg, page, len, off); - ahash_request_set_crypt(hash, &sg, NULL, len); - crypto_ahash_update(hash); + *crcp = crc32c(*crcp, virt + (off & ~PAGE_MASK), len); + + kunmap_local(virt); } -static inline void nvme_tcp_hdgst(struct ahash_request *hash, - void *pdu, size_t len) +static inline void nvme_tcp_ddgst_final(u32 *crcp, __le32 *dgst) { - struct scatterlist sg; + *dgst = cpu_to_le32(~*crcp); +} - sg_init_one(&sg, pdu, len); - ahash_request_set_crypt(hash, &sg, pdu + len, len); - crypto_ahash_digest(hash); +static inline void nvme_tcp_hdgst(void *pdu, size_t len) +{ + put_unaligned_le32(~crc32c(~0, pdu, len), pdu + len); } static int nvme_tcp_verify_hdgst(struct nvme_tcp_queue *queue, @@ -500,7 +497,7 @@ static int nvme_tcp_verify_hdgst(struct nvme_tcp_queue *queue, } recv_digest = *(__le32 *)(pdu + hdr->hlen); - nvme_tcp_hdgst(queue->rcv_hash, pdu, pdu_len); + nvme_tcp_hdgst(pdu, pdu_len); exp_digest = *(__le32 *)(pdu + hdr->hlen); if (recv_digest != exp_digest) { dev_err(queue->ctrl->ctrl.device, @@ -527,7 +524,7 @@ static int nvme_tcp_check_ddgst(struct nvme_tcp_queue *queue, void *pdu) nvme_tcp_queue_id(queue)); return -EPROTO; } - crypto_ahash_init(queue->rcv_hash); + nvme_tcp_ddgst_init(&queue->rcv_crc); return 0; } @@ -890,6 +887,17 @@ static inline void nvme_tcp_end_request(struct request *rq, u16 status) nvme_complete_rq(rq); } +static size_t crc_and_copy_to_iter(const void *addr, size_t bytes, void *crcp_, + struct iov_iter *i) +{ + u32 *crcp = crcp_; + size_t copied; + + copied = copy_to_iter(addr, bytes, i); + *crcp = crc32c(*crcp, addr, copied); + return copied; +} + static int nvme_tcp_recv_data(struct nvme_tcp_queue *queue, struct sk_buff *skb, unsigned int *offset, size_t *len) { @@ -927,8 +935,10 @@ static int nvme_tcp_recv_data(struct nvme_tcp_queue *queue, struct sk_buff *skb, iov_iter_count(&req->iter)); if (queue->data_digest) - ret = skb_copy_and_hash_datagram_iter(skb, *offset, - &req->iter, recv_len, queue->rcv_hash); + ret = __skb_datagram_iter(skb, *offset, &req->iter, + recv_len, true, + crc_and_copy_to_iter, + &queue->rcv_crc); else ret = skb_copy_datagram_iter(skb, *offset, &req->iter, recv_len); @@ -946,7 +956,8 @@ static int nvme_tcp_recv_data(struct nvme_tcp_queue *queue, struct sk_buff *skb, if (!queue->data_remaining) { if (queue->data_digest) { - nvme_tcp_ddgst_final(queue->rcv_hash, &queue->exp_ddgst); + nvme_tcp_ddgst_final(&queue->rcv_crc, + &queue->exp_ddgst); queue->ddgst_remaining = NVME_TCP_DIGEST_LENGTH; } else { if (pdu->hdr.flags & NVME_TCP_F_DATA_SUCCESS) { @@ -1148,7 +1159,7 @@ static int nvme_tcp_try_send_data(struct nvme_tcp_request *req) return ret; if (queue->data_digest) - nvme_tcp_ddgst_update(queue->snd_hash, page, + nvme_tcp_ddgst_update(&queue->snd_crc, page, offset, ret); /* @@ -1162,7 +1173,7 @@ static int nvme_tcp_try_send_data(struct nvme_tcp_request *req) /* fully successful last send in current PDU */ if (last && ret == len) { if (queue->data_digest) { - nvme_tcp_ddgst_final(queue->snd_hash, + nvme_tcp_ddgst_final(&queue->snd_crc, &req->ddgst); req->state = NVME_TCP_SEND_DDGST; req->offset = 0; @@ -1195,7 +1206,7 @@ static int nvme_tcp_try_send_cmd_pdu(struct nvme_tcp_request *req) msg.msg_flags |= MSG_EOR; if (queue->hdr_digest && !req->offset) - nvme_tcp_hdgst(queue->snd_hash, pdu, sizeof(*pdu)); + nvme_tcp_hdgst(pdu, sizeof(*pdu)); bvec_set_virt(&bvec, (void *)pdu + req->offset, len); iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, len); @@ -1208,7 +1219,7 @@ static int nvme_tcp_try_send_cmd_pdu(struct nvme_tcp_request *req) if (inline_data) { req->state = NVME_TCP_SEND_DATA; if (queue->data_digest) - crypto_ahash_init(queue->snd_hash); + nvme_tcp_ddgst_init(&queue->snd_crc); } else { nvme_tcp_done_send_req(queue); } @@ -1230,7 +1241,7 @@ static int nvme_tcp_try_send_data_pdu(struct nvme_tcp_request *req) int ret; if (queue->hdr_digest && !req->offset) - nvme_tcp_hdgst(queue->snd_hash, pdu, sizeof(*pdu)); + nvme_tcp_hdgst(pdu, sizeof(*pdu)); if (!req->h2cdata_left) msg.msg_flags |= MSG_SPLICE_PAGES; @@ -1245,7 +1256,7 @@ static int nvme_tcp_try_send_data_pdu(struct nvme_tcp_request *req) if (!len) { req->state = NVME_TCP_SEND_DATA; if (queue->data_digest) - crypto_ahash_init(queue->snd_hash); + nvme_tcp_ddgst_init(&queue->snd_crc); return 1; } req->offset += ret; @@ -1385,41 +1396,6 @@ static void nvme_tcp_io_work(struct work_struct *w) queue_work_on(queue->io_cpu, nvme_tcp_wq, &queue->io_work); } -static void nvme_tcp_free_crypto(struct nvme_tcp_queue *queue) -{ - struct crypto_ahash *tfm = crypto_ahash_reqtfm(queue->rcv_hash); - - ahash_request_free(queue->rcv_hash); - ahash_request_free(queue->snd_hash); - crypto_free_ahash(tfm); -} - -static int nvme_tcp_alloc_crypto(struct nvme_tcp_queue *queue) -{ - struct crypto_ahash *tfm; - - tfm = crypto_alloc_ahash("crc32c", 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(tfm)) - return PTR_ERR(tfm); - - queue->snd_hash = ahash_request_alloc(tfm, GFP_KERNEL); - if (!queue->snd_hash) - goto free_tfm; - ahash_request_set_callback(queue->snd_hash, 0, NULL, NULL); - - queue->rcv_hash = ahash_request_alloc(tfm, GFP_KERNEL); - if (!queue->rcv_hash) - goto free_snd_hash; - ahash_request_set_callback(queue->rcv_hash, 0, NULL, NULL); - - return 0; -free_snd_hash: - ahash_request_free(queue->snd_hash); -free_tfm: - crypto_free_ahash(tfm); - return -ENOMEM; -} - static void nvme_tcp_free_async_req(struct nvme_tcp_ctrl *ctrl) { struct nvme_tcp_request *async = &ctrl->async_req; @@ -1452,9 +1428,6 @@ static void nvme_tcp_free_queue(struct nvme_ctrl *nctrl, int qid) if (!test_and_clear_bit(NVME_TCP_Q_ALLOCATED, &queue->flags)) return; - if (queue->hdr_digest || queue->data_digest) - nvme_tcp_free_crypto(queue); - page_frag_cache_drain(&queue->pf_cache); noreclaim_flag = memalloc_noreclaim_save(); @@ -1865,21 +1838,13 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid, queue->hdr_digest = nctrl->opts->hdr_digest; queue->data_digest = nctrl->opts->data_digest; - if (queue->hdr_digest || queue->data_digest) { - ret = nvme_tcp_alloc_crypto(queue); - if (ret) { - dev_err(nctrl->device, - "failed to allocate queue %d crypto\n", qid); - goto err_sock; - } - } rcv_pdu_size = sizeof(struct nvme_tcp_rsp_pdu) + nvme_tcp_hdgst_len(queue); queue->pdu = kmalloc(rcv_pdu_size, GFP_KERNEL); if (!queue->pdu) { ret = -ENOMEM; - goto err_crypto; + goto err_sock; } dev_dbg(nctrl->device, "connecting queue %d\n", @@ -1912,9 +1877,6 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid, kernel_sock_shutdown(queue->sock, SHUT_RDWR); err_rcv_pdu: kfree(queue->pdu); -err_crypto: - if (queue->hdr_digest || queue->data_digest) - nvme_tcp_free_crypto(queue); err_sock: /* ->sock will be released by fput() */ fput(queue->sock->file); diff --git a/drivers/nvme/target/tcp.c b/drivers/nvme/target/tcp.c index 4f9cac8a5abe..cbedf61c8d0a 100644 --- a/drivers/nvme/target/tcp.c +++ b/drivers/nvme/target/tcp.c @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -18,7 +19,6 @@ #include #include #include -#include #include #include "nvmet.h" @@ -173,8 +173,8 @@ struct nvmet_tcp_queue { /* digest state */ bool hdr_digest; bool data_digest; - struct ahash_request *snd_hash; - struct ahash_request *rcv_hash; + u32 snd_crc; + u32 rcv_crc; /* TLS state */ key_serial_t tls_pskid; @@ -295,14 +295,9 @@ static inline u8 nvmet_tcp_ddgst_len(struct nvmet_tcp_queue *queue) return queue->data_digest ? NVME_TCP_DIGEST_LENGTH : 0; } -static inline void nvmet_tcp_hdgst(struct ahash_request *hash, - void *pdu, size_t len) +static inline void nvmet_tcp_hdgst(void *pdu, size_t len) { - struct scatterlist sg; - - sg_init_one(&sg, pdu, len); - ahash_request_set_crypt(hash, &sg, pdu + len, len); - crypto_ahash_digest(hash); + put_unaligned_le32(~crc32c(~0, pdu, len), pdu + len); } static int nvmet_tcp_verify_hdgst(struct nvmet_tcp_queue *queue, @@ -319,7 +314,7 @@ static int nvmet_tcp_verify_hdgst(struct nvmet_tcp_queue *queue, } recv_digest = *(__le32 *)(pdu + hdr->hlen); - nvmet_tcp_hdgst(queue->rcv_hash, pdu, len); + nvmet_tcp_hdgst(pdu, len); exp_digest = *(__le32 *)(pdu + hdr->hlen); if (recv_digest != exp_digest) { pr_err("queue %d: header digest error: recv %#x expected %#x\n", @@ -442,12 +437,20 @@ static int nvmet_tcp_map_data(struct nvmet_tcp_cmd *cmd) return NVME_SC_INTERNAL; } -static void nvmet_tcp_calc_ddgst(struct ahash_request *hash, - struct nvmet_tcp_cmd *cmd) +static void nvmet_tcp_calc_ddgst(struct nvmet_tcp_cmd *cmd) { - ahash_request_set_crypt(hash, cmd->req.sg, - (void *)&cmd->exp_ddgst, cmd->req.transfer_len); - crypto_ahash_digest(hash); + size_t total_len = cmd->req.transfer_len; + struct scatterlist *sg = cmd->req.sg; + u32 crc = ~0; + + while (total_len) { + size_t len = min_t(size_t, total_len, sg->length); + + crc = crc32c(crc, sg_virt(sg), len); + total_len -= len; + sg = sg_next(sg); + } + cmd->exp_ddgst = cpu_to_le32(~crc); } static void nvmet_setup_c2h_data_pdu(struct nvmet_tcp_cmd *cmd) @@ -474,19 +477,18 @@ static void nvmet_setup_c2h_data_pdu(struct nvmet_tcp_cmd *cmd) if (queue->data_digest) { pdu->hdr.flags |= NVME_TCP_F_DDGST; - nvmet_tcp_calc_ddgst(queue->snd_hash, cmd); + nvmet_tcp_calc_ddgst(cmd); } if (cmd->queue->hdr_digest) { pdu->hdr.flags |= NVME_TCP_F_HDGST; - nvmet_tcp_hdgst(queue->snd_hash, pdu, sizeof(*pdu)); + nvmet_tcp_hdgst(pdu, sizeof(*pdu)); } } static void nvmet_setup_r2t_pdu(struct nvmet_tcp_cmd *cmd) { struct nvme_tcp_r2t_pdu *pdu = cmd->r2t_pdu; - struct nvmet_tcp_queue *queue = cmd->queue; u8 hdgst = nvmet_tcp_hdgst_len(cmd->queue); cmd->offset = 0; @@ -504,14 +506,13 @@ static void nvmet_setup_r2t_pdu(struct nvmet_tcp_cmd *cmd) pdu->r2t_offset = cpu_to_le32(cmd->rbytes_done); if (cmd->queue->hdr_digest) { pdu->hdr.flags |= NVME_TCP_F_HDGST; - nvmet_tcp_hdgst(queue->snd_hash, pdu, sizeof(*pdu)); + nvmet_tcp_hdgst(pdu, sizeof(*pdu)); } } static void nvmet_setup_response_pdu(struct nvmet_tcp_cmd *cmd) { struct nvme_tcp_rsp_pdu *pdu = cmd->rsp_pdu; - struct nvmet_tcp_queue *queue = cmd->queue; u8 hdgst = nvmet_tcp_hdgst_len(cmd->queue); cmd->offset = 0; @@ -524,7 +525,7 @@ static void nvmet_setup_response_pdu(struct nvmet_tcp_cmd *cmd) pdu->hdr.plen = cpu_to_le32(pdu->hdr.hlen + hdgst); if (cmd->queue->hdr_digest) { pdu->hdr.flags |= NVME_TCP_F_HDGST; - nvmet_tcp_hdgst(queue->snd_hash, pdu, sizeof(*pdu)); + nvmet_tcp_hdgst(pdu, sizeof(*pdu)); } } @@ -858,42 +859,6 @@ static void nvmet_prepare_receive_pdu(struct nvmet_tcp_queue *queue) smp_store_release(&queue->rcv_state, NVMET_TCP_RECV_PDU); } -static void nvmet_tcp_free_crypto(struct nvmet_tcp_queue *queue) -{ - struct crypto_ahash *tfm = crypto_ahash_reqtfm(queue->rcv_hash); - - ahash_request_free(queue->rcv_hash); - ahash_request_free(queue->snd_hash); - crypto_free_ahash(tfm); -} - -static int nvmet_tcp_alloc_crypto(struct nvmet_tcp_queue *queue) -{ - struct crypto_ahash *tfm; - - tfm = crypto_alloc_ahash("crc32c", 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(tfm)) - return PTR_ERR(tfm); - - queue->snd_hash = ahash_request_alloc(tfm, GFP_KERNEL); - if (!queue->snd_hash) - goto free_tfm; - ahash_request_set_callback(queue->snd_hash, 0, NULL, NULL); - - queue->rcv_hash = ahash_request_alloc(tfm, GFP_KERNEL); - if (!queue->rcv_hash) - goto free_snd_hash; - ahash_request_set_callback(queue->rcv_hash, 0, NULL, NULL); - - return 0; -free_snd_hash: - ahash_request_free(queue->snd_hash); -free_tfm: - crypto_free_ahash(tfm); - return -ENOMEM; -} - - static int nvmet_tcp_handle_icreq(struct nvmet_tcp_queue *queue) { struct nvme_tcp_icreq_pdu *icreq = &queue->pdu.icreq; @@ -922,11 +887,6 @@ static int nvmet_tcp_handle_icreq(struct nvmet_tcp_queue *queue) queue->hdr_digest = !!(icreq->digest & NVME_TCP_HDR_DIGEST_ENABLE); queue->data_digest = !!(icreq->digest & NVME_TCP_DATA_DIGEST_ENABLE); - if (queue->hdr_digest || queue->data_digest) { - ret = nvmet_tcp_alloc_crypto(queue); - if (ret) - return ret; - } memset(icresp, 0, sizeof(*icresp)); icresp->hdr.type = nvme_tcp_icresp; @@ -1247,7 +1207,7 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd) { struct nvmet_tcp_queue *queue = cmd->queue; - nvmet_tcp_calc_ddgst(queue->rcv_hash, cmd); + nvmet_tcp_calc_ddgst(cmd); queue->offset = 0; queue->left = NVME_TCP_DIGEST_LENGTH; queue->rcv_state = NVMET_TCP_RECV_DDGST; @@ -1616,8 +1576,6 @@ static void nvmet_tcp_release_queue_work(struct work_struct *w) /* ->sock will be released by fput() */ fput(queue->sock->file); nvmet_tcp_free_cmds(queue); - if (queue->hdr_digest || queue->data_digest) - nvmet_tcp_free_crypto(queue); ida_free(&nvmet_tcp_queue_ida, queue->idx); page_frag_cache_drain(&queue->pf_cache); kfree(queue); diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h index bb2b751d274a..98804d51986c 100644 --- a/include/linux/skbuff.h +++ b/include/linux/skbuff.h @@ -4145,9 +4145,10 @@ static inline int skb_copy_datagram_msg(const struct sk_buff *from, int offset, } int skb_copy_and_csum_datagram_msg(struct sk_buff *skb, int hlen, struct msghdr *msg); -int skb_copy_and_hash_datagram_iter(const struct sk_buff *skb, int offset, - struct iov_iter *to, int len, - struct ahash_request *hash); +int __skb_datagram_iter(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len, bool fault_short, + size_t (*cb)(const void *, size_t, void *, + struct iov_iter *), void *data); int skb_copy_datagram_from_iter(struct sk_buff *skb, int offset, struct iov_iter *from, int len); int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *frm); diff --git a/net/core/datagram.c b/net/core/datagram.c index f0693707aece..19304c7ce7a3 100644 --- a/net/core/datagram.c +++ b/net/core/datagram.c @@ -61,7 +61,6 @@ #include #include #include -#include /* * Is a socket 'connection oriented' ? @@ -385,10 +384,10 @@ INDIRECT_CALLABLE_DECLARE(static size_t simple_copy_to_iter(const void *addr, void *data __always_unused, struct iov_iter *i)); -static int __skb_datagram_iter(const struct sk_buff *skb, int offset, - struct iov_iter *to, int len, bool fault_short, - size_t (*cb)(const void *, size_t, void *, - struct iov_iter *), void *data) +int __skb_datagram_iter(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len, bool fault_short, + size_t (*cb)(const void *, size_t, void *, + struct iov_iter *), void *data) { int start = skb_headlen(skb); int i, copy = start - offset, start_off = offset, n; @@ -481,42 +480,7 @@ static int __skb_datagram_iter(const struct sk_buff *skb, int offset, return 0; } - -static size_t hash_and_copy_to_iter(const void *addr, size_t bytes, void *hashp, - struct iov_iter *i) -{ -#ifdef CONFIG_CRYPTO_HASH - struct ahash_request *hash = hashp; - struct scatterlist sg; - size_t copied; - - copied = copy_to_iter(addr, bytes, i); - sg_init_one(&sg, addr, copied); - ahash_request_set_crypt(hash, &sg, NULL, copied); - crypto_ahash_update(hash); - return copied; -#else - return 0; -#endif -} - -/** - * skb_copy_and_hash_datagram_iter - Copy datagram to an iovec iterator - * and update a hash. - * @skb: buffer to copy - * @offset: offset in the buffer to start copying from - * @to: iovec iterator to copy to - * @len: amount of data to copy from buffer to iovec - * @hash: hash request to update - */ -int skb_copy_and_hash_datagram_iter(const struct sk_buff *skb, int offset, - struct iov_iter *to, int len, - struct ahash_request *hash) -{ - return __skb_datagram_iter(skb, offset, to, len, true, - hash_and_copy_to_iter, hash); -} -EXPORT_SYMBOL(skb_copy_and_hash_datagram_iter); +EXPORT_SYMBOL_GPL(__skb_datagram_iter); static size_t simple_copy_to_iter(const void *addr, size_t bytes, void *data __always_unused, struct iov_iter *i) diff --git a/scripts/gen-crc-consts.py b/scripts/gen-crc-consts.py new file mode 100755 index 000000000000..aa678a50897d --- /dev/null +++ b/scripts/gen-crc-consts.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: GPL-2.0-or-later +# +# Script that generates constants for computing the given CRC variant(s). +# +# Copyright 2025 Google LLC +# +# Author: Eric Biggers + +import sys + +# XOR (add) an iterable of polynomials. +def xor(iterable): + res = 0 + for val in iterable: + res ^= val + return res + +# Multiply two polynomials. +def clmul(a, b): + return xor(a << i for i in range(b.bit_length()) if (b & (1 << i)) != 0) + +# Polynomial division floor(a / b). +def div(a, b): + q = 0 + while a.bit_length() >= b.bit_length(): + q ^= 1 << (a.bit_length() - b.bit_length()) + a ^= b << (a.bit_length() - b.bit_length()) + return q + +# Reduce the polynomial 'a' modulo the polynomial 'b'. +def reduce(a, b): + return a ^ clmul(div(a, b), b) + +# Reflect the bits of a polynomial. +def bitreflect(poly, num_bits): + assert poly.bit_length() <= num_bits + return xor(((poly >> i) & 1) << (num_bits - 1 - i) for i in range(num_bits)) + +# Format a polynomial as hex. Bit-reflect it if the CRC is lsb-first. +def fmt_poly(variant, poly, num_bits): + if variant.lsb: + poly = bitreflect(poly, num_bits) + return f'0x{poly:0{2*num_bits//8}x}' + +# Print a pair of 64-bit polynomial multipliers. They are always passed in the +# order [HI64_TERMS, LO64_TERMS] but will be printed in the appropriate order. +def print_mult_pair(variant, mults): + mults = list(mults if variant.lsb else reversed(mults)) + terms = ['HI64_TERMS', 'LO64_TERMS'] if variant.lsb else ['LO64_TERMS', 'HI64_TERMS'] + for i in range(2): + print(f'\t\t{fmt_poly(variant, mults[i]["val"], 64)},\t/* {terms[i]}: {mults[i]["desc"]} */') + +# Pretty-print a polynomial. +def pprint_poly(prefix, poly): + terms = [f'x^{i}' for i in reversed(range(poly.bit_length())) + if (poly & (1 << i)) != 0] + j = 0 + while j < len(terms): + s = prefix + terms[j] + (' +' if j < len(terms) - 1 else '') + j += 1 + while j < len(terms) and len(s) < 73: + s += ' ' + terms[j] + (' +' if j < len(terms) - 1 else '') + j += 1 + print(s) + prefix = ' * ' + (' ' * (len(prefix) - 3)) + +# Print a comment describing constants generated for the given CRC variant. +def print_header(variant, what): + print('/*') + s = f'{"least" if variant.lsb else "most"}-significant-bit-first CRC-{variant.bits}' + print(f' * {what} generated for {s} using') + pprint_poly(' * G(x) = ', variant.G) + print(' */') + +class CrcVariant: + def __init__(self, bits, generator_poly, bit_order): + self.bits = bits + if bit_order not in ['lsb', 'msb']: + raise ValueError('Invalid value for bit_order') + self.lsb = bit_order == 'lsb' + self.name = f'crc{bits}_{bit_order}_0x{generator_poly:0{(2*bits+7)//8}x}' + if self.lsb: + generator_poly = bitreflect(generator_poly, bits) + self.G = generator_poly ^ (1 << bits) + +# Generate tables for CRC computation using the "slice-by-N" method. +# N=1 corresponds to the traditional byte-at-a-time table. +def gen_slicebyN_tables(variants, n): + for v in variants: + print('') + print_header(v, f'Slice-by-{n} CRC table') + print(f'static const u{v.bits} __maybe_unused {v.name}_table[{256*n}] = {{') + s = '' + for i in range(256 * n): + # The i'th table entry is the CRC of the message consisting of byte + # i % 256 followed by i // 256 zero bytes. + poly = (bitreflect(i % 256, 8) if v.lsb else i % 256) << (v.bits + 8*(i//256)) + next_entry = fmt_poly(v, reduce(poly, v.G), v.bits) + ',' + if len(s + next_entry) > 71: + print(f'\t{s}') + s = '' + s += (' ' if s else '') + next_entry + if s: + print(f'\t{s}') + print('};') + +# Generate constants for carryless multiplication based CRC computation. +def gen_x86_pclmul_consts(variants): + # These are the distances, in bits, to generate folding constants for. + FOLD_DISTANCES = [2048, 1024, 512, 256, 128] + + for v in variants: + (G, n, lsb) = (v.G, v.bits, v.lsb) + print('') + print_header(v, 'CRC folding constants') + print('static const struct {') + if not lsb: + print('\tu8 bswap_mask[16];') + for i in FOLD_DISTANCES: + print(f'\tu64 fold_across_{i}_bits_consts[2];') + print('\tu8 shuf_table[48];') + print('\tu64 barrett_reduction_consts[2];') + print(f'}} {v.name}_consts ____cacheline_aligned __maybe_unused = {{') + + # Byte-reflection mask, needed for msb-first CRCs + if not lsb: + print('\t.bswap_mask = {' + ', '.join(str(i) for i in reversed(range(16))) + '},') + + # Fold constants for all distances down to 128 bits + for i in FOLD_DISTANCES: + print(f'\t.fold_across_{i}_bits_consts = {{') + # Given 64x64 => 128 bit carryless multiplication instructions, two + # 64-bit fold constants are needed per "fold distance" i: one for + # HI64_TERMS that is basically x^(i+64) mod G and one for LO64_TERMS + # that is basically x^i mod G. The exact values however undergo a + # couple adjustments, described below. + mults = [] + for j in [64, 0]: + pow_of_x = i + j + if lsb: + # Each 64x64 => 128 bit carryless multiplication instruction + # actually generates a 127-bit product in physical bits 0 + # through 126, which in the lsb-first case represent the + # coefficients of x^1 through x^127, not x^0 through x^126. + # Thus in the lsb-first case, each such instruction + # implicitly adds an extra factor of x. The below removes a + # factor of x from each constant to compensate for this. + # For n < 64 the x could be removed from either the reduced + # part or unreduced part, but for n == 64 the reduced part + # is the only option. Just always use the reduced part. + pow_of_x -= 1 + # Make a factor of x^(64-n) be applied unreduced rather than + # reduced, to cause the product to use only the x^(64-n) and + # higher terms and always be zero in the lower terms. Usually + # this makes no difference as it does not affect the product's + # congruence class mod G and the constant remains 64-bit, but + # part of the final reduction from 128 bits does rely on this + # property when it reuses one of the constants. + pow_of_x -= 64 - n + mults.append({ 'val': reduce(1 << pow_of_x, G) << (64 - n), + 'desc': f'(x^{pow_of_x} mod G) * x^{64-n}' }) + print_mult_pair(v, mults) + print('\t},') + + # Shuffle table for handling 1..15 bytes at end + print('\t.shuf_table = {') + print('\t\t' + (16*'-1, ').rstrip()) + print('\t\t' + ''.join(f'{i:2}, ' for i in range(16)).rstrip()) + print('\t\t' + (16*'-1, ').rstrip()) + print('\t},') + + # Barrett reduction constants for reducing 128 bits to the final CRC + print('\t.barrett_reduction_consts = {') + mults = [] + + val = div(1 << (63+n), G) + desc = f'floor(x^{63+n} / G)' + if not lsb: + val = (val << 1) - (1 << 64) + desc = f'({desc} * x) - x^64' + mults.append({ 'val': val, 'desc': desc }) + + val = G - (1 << n) + desc = f'G - x^{n}' + if lsb and n == 64: + assert (val & 1) != 0 # The x^0 term should always be nonzero. + val >>= 1 + desc = f'({desc} - x^0) / x' + else: + pow_of_x = 64 - n - (1 if lsb else 0) + val <<= pow_of_x + desc = f'({desc}) * x^{pow_of_x}' + mults.append({ 'val': val, 'desc': desc }) + + print_mult_pair(v, mults) + print('\t},') + + print('};') + +def parse_crc_variants(vars_string): + variants = [] + for var_string in vars_string.split(','): + bits, bit_order, generator_poly = var_string.split('_') + assert bits.startswith('crc') + bits = int(bits.removeprefix('crc')) + assert generator_poly.startswith('0x') + generator_poly = generator_poly.removeprefix('0x') + assert len(generator_poly) % 2 == 0 + generator_poly = int(generator_poly, 16) + variants.append(CrcVariant(bits, generator_poly, bit_order)) + return variants + +if len(sys.argv) != 3: + sys.stderr.write(f'Usage: {sys.argv[0]} CONSTS_TYPE[,CONSTS_TYPE]... CRC_VARIANT[,CRC_VARIANT]...\n') + sys.stderr.write(' CONSTS_TYPE can be sliceby[1-8] or x86_pclmul\n') + sys.stderr.write(' CRC_VARIANT is crc${num_bits}_${bit_order}_${generator_poly_as_hex}\n') + sys.stderr.write(' E.g. crc16_msb_0x8bb7 or crc32_lsb_0xedb88320\n') + sys.stderr.write(' Polynomial must use the given bit_order and exclude x^{num_bits}\n') + sys.exit(1) + +print('/* SPDX-License-Identifier: GPL-2.0-or-later */') +print('/*') +print(' * CRC constants generated by:') +print(' *') +print(f' *\t{sys.argv[0]} {" ".join(sys.argv[1:])}') +print(' *') +print(' * Do not edit manually.') +print(' */') +consts_types = sys.argv[1].split(',') +variants = parse_crc_variants(sys.argv[2]) +for consts_type in consts_types: + if consts_type.startswith('sliceby'): + gen_slicebyN_tables(variants, int(consts_type.removeprefix('sliceby'))) + elif consts_type == 'x86_pclmul': + gen_x86_pclmul_consts(variants) + else: + raise ValueError(f'Unknown consts_type: {consts_type}') -- 2.49.0.391.g4bbb303af6 From 2eaa8ef2cc85b6e23313a3d15d972ce132450584 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:32:39 +0200 Subject: [PATCH 7/9] fixes Signed-off-by: Peter Jung --- arch/Kconfig | 4 +- arch/x86/tools/insn_decoder_test.c | 3 +- crypto/crc32c_generic.c | 1 + drivers/gpu/drm/amd/amdgpu/mes_v11_0.c | 4 + drivers/gpu/drm/amd/amdgpu/mes_v12_0.c | 21 +++-- .../amd/display/dc/dml2/dml21/dml21_wrapper.c | 28 +++++-- .../drm/amd/display/dc/dml2/dml2_wrapper.c | 15 +++- drivers/gpu/drm/i915/display/intel_dsb.c | 4 + .../wireless/intel/iwlwifi/pcie/trans-gen2.c | 8 +- lib/Kconfig.debug | 9 ++ lib/Makefile | 2 + lib/longest_symbol_kunit.c | 82 +++++++++++++++++++ scripts/package/PKGBUILD | 5 ++ 13 files changed, 164 insertions(+), 22 deletions(-) create mode 100644 lib/longest_symbol_kunit.c diff --git a/arch/Kconfig b/arch/Kconfig index b8a4ff365582..9b087f9bb413 100644 --- a/arch/Kconfig +++ b/arch/Kconfig @@ -1137,7 +1137,7 @@ config ARCH_MMAP_RND_BITS int "Number of bits to use for ASLR of mmap base address" if EXPERT range ARCH_MMAP_RND_BITS_MIN ARCH_MMAP_RND_BITS_MAX default ARCH_MMAP_RND_BITS_DEFAULT if ARCH_MMAP_RND_BITS_DEFAULT - default ARCH_MMAP_RND_BITS_MIN + default ARCH_MMAP_RND_BITS_MAX depends on HAVE_ARCH_MMAP_RND_BITS help This value can be used to select the number of bits to use to @@ -1171,7 +1171,7 @@ config ARCH_MMAP_RND_COMPAT_BITS int "Number of bits to use for ASLR of mmap base address for compatible applications" if EXPERT range ARCH_MMAP_RND_COMPAT_BITS_MIN ARCH_MMAP_RND_COMPAT_BITS_MAX default ARCH_MMAP_RND_COMPAT_BITS_DEFAULT if ARCH_MMAP_RND_COMPAT_BITS_DEFAULT - default ARCH_MMAP_RND_COMPAT_BITS_MIN + default ARCH_MMAP_RND_COMPAT_BITS_MAX depends on HAVE_ARCH_MMAP_RND_COMPAT_BITS help This value can be used to select the number of bits to use to diff --git a/arch/x86/tools/insn_decoder_test.c b/arch/x86/tools/insn_decoder_test.c index 472540aeabc2..6c2986d2ad11 100644 --- a/arch/x86/tools/insn_decoder_test.c +++ b/arch/x86/tools/insn_decoder_test.c @@ -10,6 +10,7 @@ #include #include #include +#include #define unlikely(cond) (cond) @@ -106,7 +107,7 @@ static void parse_args(int argc, char **argv) } } -#define BUFSIZE 256 +#define BUFSIZE (256 + KSYM_NAME_LEN) int main(int argc, char **argv) { diff --git a/crypto/crc32c_generic.c b/crypto/crc32c_generic.c index 985da981d6e2..99713d0c77c0 100644 --- a/crypto/crc32c_generic.c +++ b/crypto/crc32c_generic.c @@ -220,3 +220,4 @@ MODULE_DESCRIPTION("CRC32c (Castagnoli) calculations wrapper for lib/crc32c"); MODULE_LICENSE("GPL"); MODULE_ALIAS_CRYPTO("crc32c"); MODULE_ALIAS_CRYPTO("crc32c-generic"); +MODULE_ALIAS_CRYPTO("crc32c-intel"); diff --git a/drivers/gpu/drm/amd/amdgpu/mes_v11_0.c b/drivers/gpu/drm/amd/amdgpu/mes_v11_0.c index f9a4d08eef92..0f808ffcab94 100644 --- a/drivers/gpu/drm/amd/amdgpu/mes_v11_0.c +++ b/drivers/gpu/drm/amd/amdgpu/mes_v11_0.c @@ -899,6 +899,10 @@ static void mes_v11_0_get_fw_version(struct amdgpu_device *adev) { int pipe; + /* return early if we have already fetched these */ + if (adev->mes.sched_version && adev->mes.kiq_version) + return; + /* get MES scheduler/KIQ versions */ mutex_lock(&adev->srbm_mutex); diff --git a/drivers/gpu/drm/amd/amdgpu/mes_v12_0.c b/drivers/gpu/drm/amd/amdgpu/mes_v12_0.c index 0fd0fa6ed518..6b121c2723d6 100644 --- a/drivers/gpu/drm/amd/amdgpu/mes_v12_0.c +++ b/drivers/gpu/drm/amd/amdgpu/mes_v12_0.c @@ -1390,17 +1390,20 @@ static int mes_v12_0_queue_init(struct amdgpu_device *adev, mes_v12_0_queue_init_register(ring); } - /* get MES scheduler/KIQ versions */ - mutex_lock(&adev->srbm_mutex); - soc21_grbm_select(adev, 3, pipe, 0, 0); + if (((pipe == AMDGPU_MES_SCHED_PIPE) && !adev->mes.sched_version) || + ((pipe == AMDGPU_MES_KIQ_PIPE) && !adev->mes.kiq_version)) { + /* get MES scheduler/KIQ versions */ + mutex_lock(&adev->srbm_mutex); + soc21_grbm_select(adev, 3, pipe, 0, 0); - if (pipe == AMDGPU_MES_SCHED_PIPE) - adev->mes.sched_version = RREG32_SOC15(GC, 0, regCP_MES_GP3_LO); - else if (pipe == AMDGPU_MES_KIQ_PIPE && adev->enable_mes_kiq) - adev->mes.kiq_version = RREG32_SOC15(GC, 0, regCP_MES_GP3_LO); + if (pipe == AMDGPU_MES_SCHED_PIPE) + adev->mes.sched_version = RREG32_SOC15(GC, 0, regCP_MES_GP3_LO); + else if (pipe == AMDGPU_MES_KIQ_PIPE && adev->enable_mes_kiq) + adev->mes.kiq_version = RREG32_SOC15(GC, 0, regCP_MES_GP3_LO); - soc21_grbm_select(adev, 0, 0, 0, 0); - mutex_unlock(&adev->srbm_mutex); + soc21_grbm_select(adev, 0, 0, 0, 0); + mutex_unlock(&adev->srbm_mutex); + } return 0; } diff --git a/drivers/gpu/drm/amd/display/dc/dml2/dml21/dml21_wrapper.c b/drivers/gpu/drm/amd/display/dc/dml2/dml21/dml21_wrapper.c index fb80ba9287b6..d6fd13f43c08 100644 --- a/drivers/gpu/drm/amd/display/dc/dml2/dml21/dml21_wrapper.c +++ b/drivers/gpu/drm/amd/display/dc/dml2/dml21/dml21_wrapper.c @@ -2,6 +2,7 @@ // // Copyright 2024 Advanced Micro Devices, Inc. +#include #include "dml2_internal_types.h" #include "dml_top.h" @@ -13,11 +14,11 @@ static bool dml21_allocate_memory(struct dml2_context **dml_ctx) { - *dml_ctx = kzalloc(sizeof(struct dml2_context), GFP_KERNEL); + *dml_ctx = vzalloc(sizeof(struct dml2_context)); if (!(*dml_ctx)) return false; - (*dml_ctx)->v21.dml_init.dml2_instance = kzalloc(sizeof(struct dml2_instance), GFP_KERNEL); + (*dml_ctx)->v21.dml_init.dml2_instance = vzalloc(sizeof(struct dml2_instance)); if (!((*dml_ctx)->v21.dml_init.dml2_instance)) return false; @@ -27,7 +28,7 @@ static bool dml21_allocate_memory(struct dml2_context **dml_ctx) (*dml_ctx)->v21.mode_support.display_config = &(*dml_ctx)->v21.display_config; (*dml_ctx)->v21.mode_programming.display_config = (*dml_ctx)->v21.mode_support.display_config; - (*dml_ctx)->v21.mode_programming.programming = kzalloc(sizeof(struct dml2_display_cfg_programming), GFP_KERNEL); + (*dml_ctx)->v21.mode_programming.programming = vzalloc(sizeof(struct dml2_display_cfg_programming)); if (!((*dml_ctx)->v21.mode_programming.programming)) return false; @@ -86,6 +87,8 @@ static void dml21_init(const struct dc *in_dc, struct dml2_context **dml_ctx, co /* Store configuration options */ (*dml_ctx)->config = *config; + DC_FP_START(); + /*Initialize SOCBB and DCNIP params */ dml21_initialize_soc_bb_params(&(*dml_ctx)->v21.dml_init, config, in_dc); dml21_initialize_ip_params(&(*dml_ctx)->v21.dml_init, config, in_dc); @@ -96,6 +99,8 @@ static void dml21_init(const struct dc *in_dc, struct dml2_context **dml_ctx, co /*Initialize DML21 instance */ dml2_initialize_instance(&(*dml_ctx)->v21.dml_init); + + DC_FP_END(); } bool dml21_create(const struct dc *in_dc, struct dml2_context **dml_ctx, const struct dml2_configuration_options *config) @@ -111,8 +116,8 @@ bool dml21_create(const struct dc *in_dc, struct dml2_context **dml_ctx, const s void dml21_destroy(struct dml2_context *dml2) { - kfree(dml2->v21.dml_init.dml2_instance); - kfree(dml2->v21.mode_programming.programming); + vfree(dml2->v21.dml_init.dml2_instance); + vfree(dml2->v21.mode_programming.programming); } static void dml21_calculate_rq_and_dlg_params(const struct dc *dc, struct dc_state *context, struct resource_context *out_new_hw_state, @@ -269,11 +274,16 @@ bool dml21_validate(const struct dc *in_dc, struct dc_state *context, struct dml { bool out = false; + DC_FP_START(); + /* Use dml_validate_only for fast_validate path */ - if (fast_validate) { + if (fast_validate) out = dml21_check_mode_support(in_dc, context, dml_ctx); - } else + else out = dml21_mode_check_and_programming(in_dc, context, dml_ctx); + + DC_FP_END(); + return out; } @@ -412,8 +422,12 @@ void dml21_copy(struct dml2_context *dst_dml_ctx, dst_dml_ctx->v21.mode_programming.programming = dst_dml2_programming; + DC_FP_START(); + /* need to initialize copied instance for internal references to be correct */ dml2_initialize_instance(&dst_dml_ctx->v21.dml_init); + + DC_FP_END(); } bool dml21_create_copy(struct dml2_context **dst_dml_ctx, diff --git a/drivers/gpu/drm/amd/display/dc/dml2/dml2_wrapper.c b/drivers/gpu/drm/amd/display/dc/dml2/dml2_wrapper.c index 68b882d28195..d0f9df2daeb4 100644 --- a/drivers/gpu/drm/amd/display/dc/dml2/dml2_wrapper.c +++ b/drivers/gpu/drm/amd/display/dc/dml2/dml2_wrapper.c @@ -24,6 +24,8 @@ * */ +#include + #include "display_mode_core.h" #include "dml2_internal_types.h" #include "dml2_utils.h" @@ -732,17 +734,22 @@ bool dml2_validate(const struct dc *in_dc, struct dc_state *context, struct dml2 return out; } + DC_FP_START(); + /* Use dml_validate_only for fast_validate path */ if (fast_validate) out = dml2_validate_only(context); else out = dml2_validate_and_build_resource(in_dc, context); + + DC_FP_END(); + return out; } static inline struct dml2_context *dml2_allocate_memory(void) { - return (struct dml2_context *) kzalloc(sizeof(struct dml2_context), GFP_KERNEL); + return (struct dml2_context *) vzalloc(sizeof(struct dml2_context)); } static void dml2_init(const struct dc *in_dc, const struct dml2_configuration_options *config, struct dml2_context **dml2) @@ -776,11 +783,15 @@ static void dml2_init(const struct dc *in_dc, const struct dml2_configuration_op break; } + DC_FP_START(); + initialize_dml2_ip_params(*dml2, in_dc, &(*dml2)->v20.dml_core_ctx.ip); initialize_dml2_soc_bbox(*dml2, in_dc, &(*dml2)->v20.dml_core_ctx.soc); initialize_dml2_soc_states(*dml2, in_dc, &(*dml2)->v20.dml_core_ctx.soc, &(*dml2)->v20.dml_core_ctx.states); + + DC_FP_END(); } bool dml2_create(const struct dc *in_dc, const struct dml2_configuration_options *config, struct dml2_context **dml2) @@ -806,7 +817,7 @@ void dml2_destroy(struct dml2_context *dml2) if (dml2->architecture == dml2_architecture_21) dml21_destroy(dml2); - kfree(dml2); + vfree(dml2); } void dml2_extract_dram_and_fclk_change_support(struct dml2_context *dml2, diff --git a/drivers/gpu/drm/i915/display/intel_dsb.c b/drivers/gpu/drm/i915/display/intel_dsb.c index e6f8fc743fb4..73f6febfb6c4 100644 --- a/drivers/gpu/drm/i915/display/intel_dsb.c +++ b/drivers/gpu/drm/i915/display/intel_dsb.c @@ -763,6 +763,10 @@ struct intel_dsb *intel_dsb_prepare(struct intel_atomic_state *state, if (!i915->display.params.enable_dsb) return NULL; + /* TODO: DSB is broken in Xe KMD, so disabling it until fixed */ + if (!IS_ENABLED(I915)) + return NULL; + dsb = kzalloc(sizeof(*dsb), GFP_KERNEL); if (!dsb) goto out; diff --git a/drivers/net/wireless/intel/iwlwifi/pcie/trans-gen2.c b/drivers/net/wireless/intel/iwlwifi/pcie/trans-gen2.c index 793514a1852a..e37fa5ae97f6 100644 --- a/drivers/net/wireless/intel/iwlwifi/pcie/trans-gen2.c +++ b/drivers/net/wireless/intel/iwlwifi/pcie/trans-gen2.c @@ -147,8 +147,14 @@ static void _iwl_trans_pcie_gen2_stop_device(struct iwl_trans *trans) return; if (trans->state >= IWL_TRANS_FW_STARTED && - trans_pcie->fw_reset_handshake) + trans_pcie->fw_reset_handshake) { + /* + * Reset handshake can dump firmware on timeout, but that + * should assume that the firmware is already dead. + */ + trans->state = IWL_TRANS_NO_FW; iwl_trans_pcie_fw_reset_handshake(trans); + } trans_pcie->is_down = true; diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 35796c290ca3..a6da7f4411f4 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -2838,6 +2838,15 @@ config FORTIFY_KUNIT_TEST by the str*() and mem*() family of functions. For testing runtime traps of FORTIFY_SOURCE, see LKDTM's "FORTIFY_*" tests. +config LONGEST_SYM_KUNIT_TEST + tristate "Test the longest symbol possible" if !KUNIT_ALL_TESTS + depends on KUNIT && KPROBES + default KUNIT_ALL_TESTS + help + Tests the longest symbol possible + + If unsure, say N. + config HW_BREAKPOINT_KUNIT_TEST bool "Test hw_breakpoint constraints accounting" if !KUNIT_ALL_TESTS depends on HAVE_HW_BREAKPOINT diff --git a/lib/Makefile b/lib/Makefile index 4f3d00a2fd65..9a54526008d0 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -398,6 +398,8 @@ obj-$(CONFIG_FORTIFY_KUNIT_TEST) += fortify_kunit.o obj-$(CONFIG_CRC_KUNIT_TEST) += crc_kunit.o obj-$(CONFIG_SIPHASH_KUNIT_TEST) += siphash_kunit.o obj-$(CONFIG_USERCOPY_KUNIT_TEST) += usercopy_kunit.o +obj-$(CONFIG_LONGEST_SYM_KUNIT_TEST) += longest_symbol_kunit.o +CFLAGS_longest_symbol_kunit.o += $(call cc-disable-warning, missing-prototypes) obj-$(CONFIG_GENERIC_LIB_DEVMEM_IS_ALLOWED) += devmem_is_allowed.o diff --git a/lib/longest_symbol_kunit.c b/lib/longest_symbol_kunit.c new file mode 100644 index 000000000000..e3c28ff1807f --- /dev/null +++ b/lib/longest_symbol_kunit.c @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Test the longest symbol length. Execute with: + * ./tools/testing/kunit/kunit.py run longest-symbol + * --arch=x86_64 --kconfig_add CONFIG_KPROBES=y --kconfig_add CONFIG_MODULES=y + * --kconfig_add CONFIG_RETPOLINE=n --kconfig_add CONFIG_CFI_CLANG=n + * --kconfig_add CONFIG_MITIGATION_RETPOLINE=n + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#define DI(name) s##name##name +#define DDI(name) DI(n##name##name) +#define DDDI(name) DDI(n##name##name) +#define DDDDI(name) DDDI(n##name##name) +#define DDDDDI(name) DDDDI(n##name##name) + +/*Generate a symbol whose name length is 511 */ +#define LONGEST_SYM_NAME DDDDDI(g1h2i3j4k5l6m7n) + +#define RETURN_LONGEST_SYM 0xAAAAA + +noinline int LONGEST_SYM_NAME(void); +noinline int LONGEST_SYM_NAME(void) +{ + return RETURN_LONGEST_SYM; +} + +_Static_assert(sizeof(__stringify(LONGEST_SYM_NAME)) == KSYM_NAME_LEN, +"Incorrect symbol length found. Expected KSYM_NAME_LEN: " +__stringify(KSYM_NAME_LEN) ", but found: " +__stringify(sizeof(LONGEST_SYM_NAME))); + +static void test_longest_symbol(struct kunit *test) +{ + KUNIT_EXPECT_EQ(test, RETURN_LONGEST_SYM, LONGEST_SYM_NAME()); +}; + +static void test_longest_symbol_kallsyms(struct kunit *test) +{ + unsigned long (*kallsyms_lookup_name)(const char *name); + static int (*longest_sym)(void); + + struct kprobe kp = { + .symbol_name = "kallsyms_lookup_name", + }; + + if (register_kprobe(&kp) < 0) { + pr_info("%s: kprobe not registered", __func__); + KUNIT_FAIL(test, "test_longest_symbol kallsyms: kprobe not registered\n"); + return; + } + + kunit_warn(test, "test_longest_symbol kallsyms: kprobe registered\n"); + kallsyms_lookup_name = (unsigned long (*)(const char *name))kp.addr; + unregister_kprobe(&kp); + + longest_sym = + (void *) kallsyms_lookup_name(__stringify(LONGEST_SYM_NAME)); + KUNIT_EXPECT_EQ(test, RETURN_LONGEST_SYM, longest_sym()); +}; + +static struct kunit_case longest_symbol_test_cases[] = { + KUNIT_CASE(test_longest_symbol), + KUNIT_CASE(test_longest_symbol_kallsyms), + {} +}; + +static struct kunit_suite longest_symbol_test_suite = { + .name = "longest-symbol", + .test_cases = longest_symbol_test_cases, +}; +kunit_test_suite(longest_symbol_test_suite); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Test the longest symbol length"); +MODULE_AUTHOR("Sergio González Collado"); diff --git a/scripts/package/PKGBUILD b/scripts/package/PKGBUILD index 0cf3a55b05e1..a27d4344a4e8 100644 --- a/scripts/package/PKGBUILD +++ b/scripts/package/PKGBUILD @@ -90,6 +90,11 @@ _package-headers() { "${srctree}/scripts/package/install-extmod-build" "${builddir}" fi + # required when DEBUG_INFO_BTF_MODULES is enabled + if [ -f tools/bpf/resolve_btfids/resolve_btfids ]; then + install -Dt "$builddir/tools/bpf/resolve_btfids" tools/bpf/resolve_btfids/resolve_btfids + fi + echo "Installing System.map and config..." mkdir -p "${builddir}" cp System.map "${builddir}/System.map" -- 2.49.0.391.g4bbb303af6 From e6ce0ab27dc8974803142571eda482e1ac737c97 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:33:19 +0200 Subject: [PATCH 8/9] t2 Signed-off-by: Peter Jung --- .../ABI/testing/sysfs-driver-hid-appletb-kbd | 13 + MAINTAINERS | 8 + drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 3 + drivers/gpu/drm/drm_format_helper.c | 54 + drivers/gpu/drm/i915/display/intel_ddi.c | 4 + drivers/gpu/drm/i915/display/intel_fbdev.c | 6 +- drivers/gpu/drm/i915/display/intel_quirks.c | 15 + drivers/gpu/drm/i915/display/intel_quirks.h | 1 + .../gpu/drm/tests/drm_format_helper_test.c | 81 ++ drivers/gpu/drm/tiny/Kconfig | 12 + drivers/gpu/drm/tiny/Makefile | 1 + drivers/gpu/drm/tiny/appletbdrm.c | 840 ++++++++++++ drivers/gpu/vga/vga_switcheroo.c | 7 +- drivers/hid/Kconfig | 36 +- drivers/hid/Makefile | 6 + drivers/hid/dockchannel-hid/Kconfig | 14 + drivers/hid/dockchannel-hid/Makefile | 6 + drivers/hid/dockchannel-hid/dockchannel-hid.c | 1213 +++++++++++++++++ drivers/hid/hid-apple.c | 91 +- drivers/hid/hid-appletb-bl.c | 204 +++ drivers/hid/hid-appletb-kbd.c | 507 +++++++ drivers/hid/hid-core.c | 11 +- drivers/hid/hid-ids.h | 25 +- drivers/hid/hid-magicmouse.c | 908 ++++++++++-- drivers/hid/hid-multitouch.c | 60 +- drivers/hid/hid-quirks.c | 17 +- drivers/hid/spi-hid/Kconfig | 26 + drivers/hid/spi-hid/Makefile | 10 + drivers/hid/spi-hid/spi-hid-apple-core.c | 1194 ++++++++++++++++ drivers/hid/spi-hid/spi-hid-apple-of.c | 151 ++ drivers/hid/spi-hid/spi-hid-apple.h | 35 + drivers/hwmon/applesmc.c | 1138 ++++++++++++---- drivers/nvme/host/apple.c | 2 +- drivers/pci/vgaarb.c | 1 + drivers/platform/x86/apple-gmux.c | 18 + drivers/soc/apple/Kconfig | 24 + drivers/soc/apple/Makefile | 6 + drivers/soc/apple/dockchannel.c | 406 ++++++ drivers/soc/apple/rtkit-helper.c | 151 ++ drivers/soc/apple/rtkit.c | 2 +- drivers/staging/Kconfig | 2 + drivers/staging/Makefile | 1 + drivers/staging/apple-bce/Kconfig | 18 + drivers/staging/apple-bce/Makefile | 28 + drivers/staging/apple-bce/apple_bce.c | 445 ++++++ drivers/staging/apple-bce/apple_bce.h | 38 + drivers/staging/apple-bce/audio/audio.c | 711 ++++++++++ drivers/staging/apple-bce/audio/audio.h | 125 ++ drivers/staging/apple-bce/audio/description.h | 42 + drivers/staging/apple-bce/audio/pcm.c | 308 +++++ drivers/staging/apple-bce/audio/pcm.h | 16 + drivers/staging/apple-bce/audio/protocol.c | 347 +++++ drivers/staging/apple-bce/audio/protocol.h | 147 ++ .../staging/apple-bce/audio/protocol_bce.c | 226 +++ .../staging/apple-bce/audio/protocol_bce.h | 72 + drivers/staging/apple-bce/mailbox.c | 151 ++ drivers/staging/apple-bce/mailbox.h | 53 + drivers/staging/apple-bce/queue.c | 390 ++++++ drivers/staging/apple-bce/queue.h | 177 +++ drivers/staging/apple-bce/queue_dma.c | 220 +++ drivers/staging/apple-bce/queue_dma.h | 50 + drivers/staging/apple-bce/vhci/command.h | 204 +++ drivers/staging/apple-bce/vhci/queue.c | 268 ++++ drivers/staging/apple-bce/vhci/queue.h | 76 ++ drivers/staging/apple-bce/vhci/transfer.c | 661 +++++++++ drivers/staging/apple-bce/vhci/transfer.h | 73 + drivers/staging/apple-bce/vhci/vhci.c | 759 +++++++++++ drivers/staging/apple-bce/vhci/vhci.h | 52 + include/drm/drm_format_helper.h | 3 + include/linux/hid.h | 6 +- include/linux/soc/apple/dockchannel.h | 26 + include/linux/soc/apple/rtkit.h | 2 +- 72 files changed, 12560 insertions(+), 444 deletions(-) create mode 100644 Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd create mode 100644 drivers/gpu/drm/tiny/appletbdrm.c create mode 100644 drivers/hid/dockchannel-hid/Kconfig create mode 100644 drivers/hid/dockchannel-hid/Makefile create mode 100644 drivers/hid/dockchannel-hid/dockchannel-hid.c create mode 100644 drivers/hid/hid-appletb-bl.c create mode 100644 drivers/hid/hid-appletb-kbd.c create mode 100644 drivers/hid/spi-hid/Kconfig create mode 100644 drivers/hid/spi-hid/Makefile create mode 100644 drivers/hid/spi-hid/spi-hid-apple-core.c create mode 100644 drivers/hid/spi-hid/spi-hid-apple-of.c create mode 100644 drivers/hid/spi-hid/spi-hid-apple.h create mode 100644 drivers/soc/apple/dockchannel.c create mode 100644 drivers/soc/apple/rtkit-helper.c create mode 100644 drivers/staging/apple-bce/Kconfig create mode 100644 drivers/staging/apple-bce/Makefile create mode 100644 drivers/staging/apple-bce/apple_bce.c create mode 100644 drivers/staging/apple-bce/apple_bce.h create mode 100644 drivers/staging/apple-bce/audio/audio.c create mode 100644 drivers/staging/apple-bce/audio/audio.h create mode 100644 drivers/staging/apple-bce/audio/description.h create mode 100644 drivers/staging/apple-bce/audio/pcm.c create mode 100644 drivers/staging/apple-bce/audio/pcm.h create mode 100644 drivers/staging/apple-bce/audio/protocol.c create mode 100644 drivers/staging/apple-bce/audio/protocol.h create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.c create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.h create mode 100644 drivers/staging/apple-bce/mailbox.c create mode 100644 drivers/staging/apple-bce/mailbox.h create mode 100644 drivers/staging/apple-bce/queue.c create mode 100644 drivers/staging/apple-bce/queue.h create mode 100644 drivers/staging/apple-bce/queue_dma.c create mode 100644 drivers/staging/apple-bce/queue_dma.h create mode 100644 drivers/staging/apple-bce/vhci/command.h create mode 100644 drivers/staging/apple-bce/vhci/queue.c create mode 100644 drivers/staging/apple-bce/vhci/queue.h create mode 100644 drivers/staging/apple-bce/vhci/transfer.c create mode 100644 drivers/staging/apple-bce/vhci/transfer.h create mode 100644 drivers/staging/apple-bce/vhci/vhci.c create mode 100644 drivers/staging/apple-bce/vhci/vhci.h create mode 100644 include/linux/soc/apple/dockchannel.h diff --git a/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd new file mode 100644 index 000000000000..2a19584d091e --- /dev/null +++ b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd @@ -0,0 +1,13 @@ +What: /sys/bus/hid/drivers/hid-appletb-kbd//mode +Date: September, 2023 +KernelVersion: 6.5 +Contact: linux-input@vger.kernel.org +Description: + The set of keys displayed on the Touch Bar. + Valid values are: + == ================= + 0 Escape key only + 1 Function keys + 2 Media/brightness keys + 3 None + == ================= diff --git a/MAINTAINERS b/MAINTAINERS index 3e00f5654f60..2b1f3e8bdbdd 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -7164,6 +7164,14 @@ S: Supported T: git https://gitlab.freedesktop.org/drm/misc/kernel.git F: drivers/gpu/drm/sun4i/sun8i* +DRM DRIVER FOR APPLE TOUCH BARS +M: Aun-Ali Zaidi +M: Aditya Garg +L: dri-devel@lists.freedesktop.org +S: Maintained +T: git https://gitlab.freedesktop.org/drm/misc/kernel.git +F: drivers/gpu/drm/tiny/appletbdrm.c + DRM DRIVER FOR ARM PL111 CLCD M: Linus Walleij S: Maintained diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c index 73b1a742c5e4..b7c6a2dd8d35 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c @@ -2258,6 +2258,9 @@ static int amdgpu_pci_probe(struct pci_dev *pdev, int ret, retry = 0, i; bool supports_atomic = false; + if (vga_switcheroo_client_probe_defer(pdev)) + return -EPROBE_DEFER; + /* skip devices which are owned by radeon */ for (i = 0; i < ARRAY_SIZE(amdgpu_unsupported_pciidlist); i++) { if (amdgpu_unsupported_pciidlist[i] == pdev->device) diff --git a/drivers/gpu/drm/drm_format_helper.c b/drivers/gpu/drm/drm_format_helper.c index b1be458ed4dd..4f60c8d8f63e 100644 --- a/drivers/gpu/drm/drm_format_helper.c +++ b/drivers/gpu/drm/drm_format_helper.c @@ -702,6 +702,57 @@ void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pi } EXPORT_SYMBOL(drm_fb_xrgb8888_to_rgb888); +static void drm_fb_xrgb8888_to_bgr888_line(void *dbuf, const void *sbuf, unsigned int pixels) +{ + u8 *dbuf8 = dbuf; + const __le32 *sbuf32 = sbuf; + unsigned int x; + u32 pix; + + for (x = 0; x < pixels; x++) { + pix = le32_to_cpu(sbuf32[x]); + /* write red-green-blue to output in little endianness */ + *dbuf8++ = (pix & 0x00ff0000) >> 16; + *dbuf8++ = (pix & 0x0000ff00) >> 8; + *dbuf8++ = (pix & 0x000000ff) >> 0; + } +} + +/** + * drm_fb_xrgb8888_to_bgr888 - Convert XRGB8888 to BGR888 clip buffer + * @dst: Array of BGR888 destination buffers + * @dst_pitch: Array of numbers of bytes between the start of two consecutive scanlines + * within @dst; can be NULL if scanlines are stored next to each other. + * @src: Array of XRGB8888 source buffers + * @fb: DRM framebuffer + * @clip: Clip rectangle area to copy + * @state: Transform and conversion state + * + * This function copies parts of a framebuffer to display memory and converts the + * color format during the process. Destination and framebuffer formats must match. The + * parameters @dst, @dst_pitch and @src refer to arrays. Each array must have at + * least as many entries as there are planes in @fb's format. Each entry stores the + * value for the format's respective color plane at the same index. + * + * This function does not apply clipping on @dst (i.e. the destination is at the + * top-left corner). + * + * Drivers can use this function for BGR888 devices that don't natively + * support XRGB8888. + */ +void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + const struct drm_rect *clip, struct drm_format_conv_state *state) +{ + static const u8 dst_pixsize[DRM_FORMAT_MAX_PLANES] = { + 3, + }; + + drm_fb_xfrm(dst, dst_pitch, dst_pixsize, src, fb, clip, false, state, + drm_fb_xrgb8888_to_bgr888_line); +} +EXPORT_SYMBOL(drm_fb_xrgb8888_to_bgr888); + static void drm_fb_xrgb8888_to_argb8888_line(void *dbuf, const void *sbuf, unsigned int pixels) { __le32 *dbuf32 = dbuf; @@ -1035,6 +1086,9 @@ int drm_fb_blit(struct iosys_map *dst, const unsigned int *dst_pitch, uint32_t d } else if (dst_format == DRM_FORMAT_RGB888) { drm_fb_xrgb8888_to_rgb888(dst, dst_pitch, src, fb, clip, state); return 0; + } else if (dst_format == DRM_FORMAT_BGR888) { + drm_fb_xrgb8888_to_bgr888(dst, dst_pitch, src, fb, clip, state); + return 0; } else if (dst_format == DRM_FORMAT_ARGB8888) { drm_fb_xrgb8888_to_argb8888(dst, dst_pitch, src, fb, clip, state); return 0; diff --git a/drivers/gpu/drm/i915/display/intel_ddi.c b/drivers/gpu/drm/i915/display/intel_ddi.c index ff2cf3daa7a2..bffff11bf040 100644 --- a/drivers/gpu/drm/i915/display/intel_ddi.c +++ b/drivers/gpu/drm/i915/display/intel_ddi.c @@ -4868,6 +4868,7 @@ static int intel_ddi_init_hdmi_connector(struct intel_digital_port *dig_port) static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) { + struct intel_display *display = to_intel_display(dig_port); struct drm_i915_private *dev_priv = to_i915(dig_port->base.base.dev); if (dig_port->base.port != PORT_A) @@ -4876,6 +4877,9 @@ static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) if (dig_port->ddi_a_4_lanes) return false; + if (intel_has_quirk(display, QUIRK_DDI_A_FORCE_4_LANES)) + return true; + /* Broxton/Geminilake: Bspec says that DDI_A_4_LANES is the only * supported configuration */ diff --git a/drivers/gpu/drm/i915/display/intel_fbdev.c b/drivers/gpu/drm/i915/display/intel_fbdev.c index 00852ff5b247..4c56f1b622be 100644 --- a/drivers/gpu/drm/i915/display/intel_fbdev.c +++ b/drivers/gpu/drm/i915/display/intel_fbdev.c @@ -197,10 +197,10 @@ static int intelfb_create(struct drm_fb_helper *helper, ifbdev->fb = NULL; if (fb && - (sizes->fb_width > fb->base.width || - sizes->fb_height > fb->base.height)) { + (sizes->fb_width != fb->base.width || + sizes->fb_height != fb->base.height)) { drm_dbg_kms(&dev_priv->drm, - "BIOS fb too small (%dx%d), we require (%dx%d)," + "BIOS fb not valid (%dx%d), we require (%dx%d)," " releasing it\n", fb->base.width, fb->base.height, sizes->fb_width, sizes->fb_height); diff --git a/drivers/gpu/drm/i915/display/intel_quirks.c b/drivers/gpu/drm/i915/display/intel_quirks.c index 8b30e9fd936e..2bab4111962d 100644 --- a/drivers/gpu/drm/i915/display/intel_quirks.c +++ b/drivers/gpu/drm/i915/display/intel_quirks.c @@ -64,6 +64,18 @@ static void quirk_increase_ddi_disabled_time(struct intel_display *display) drm_info(display->drm, "Applying Increase DDI Disabled quirk\n"); } +/* + * In some cases, the firmware might not set the lane count to 4 (for example, + * when booting in some dual GPU Macs with the dGPU as the default GPU), this + * quirk is used to force it as otherwise it might not be possible to compute a + * valid link configuration. + */ +static void quirk_ddi_a_force_4_lanes(struct intel_display *display) +{ + intel_set_quirk(display, QUIRK_DDI_A_FORCE_4_LANES); + drm_info(display->drm, "Applying DDI A Forced 4 Lanes quirk\n"); +} + static void quirk_no_pps_backlight_power_hook(struct intel_display *display) { intel_set_quirk(display, QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK); @@ -229,6 +241,9 @@ static struct intel_quirk intel_quirks[] = { { 0x3184, 0x1019, 0xa94d, quirk_increase_ddi_disabled_time }, /* HP Notebook - 14-r206nv */ { 0x0f31, 0x103c, 0x220f, quirk_invert_brightness }, + + /* Apple MacBookPro15,1 */ + { 0x3e9b, 0x106b, 0x0176, quirk_ddi_a_force_4_lanes }, }; static const struct intel_dpcd_quirk intel_dpcd_quirks[] = { diff --git a/drivers/gpu/drm/i915/display/intel_quirks.h b/drivers/gpu/drm/i915/display/intel_quirks.h index cafdebda7535..a5296f82776e 100644 --- a/drivers/gpu/drm/i915/display/intel_quirks.h +++ b/drivers/gpu/drm/i915/display/intel_quirks.h @@ -20,6 +20,7 @@ enum intel_quirk_id { QUIRK_LVDS_SSC_DISABLE, QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK, QUIRK_FW_SYNC_LEN, + QUIRK_DDI_A_FORCE_4_LANES, }; void intel_init_quirks(struct intel_display *display); diff --git a/drivers/gpu/drm/tests/drm_format_helper_test.c b/drivers/gpu/drm/tests/drm_format_helper_test.c index 08992636ec05..35cd3405d045 100644 --- a/drivers/gpu/drm/tests/drm_format_helper_test.c +++ b/drivers/gpu/drm/tests/drm_format_helper_test.c @@ -60,6 +60,11 @@ struct convert_to_rgb888_result { const u8 expected[TEST_BUF_SIZE]; }; +struct convert_to_bgr888_result { + unsigned int dst_pitch; + const u8 expected[TEST_BUF_SIZE]; +}; + struct convert_to_argb8888_result { unsigned int dst_pitch; const u32 expected[TEST_BUF_SIZE]; @@ -107,6 +112,7 @@ struct convert_xrgb8888_case { struct convert_to_argb1555_result argb1555_result; struct convert_to_rgba5551_result rgba5551_result; struct convert_to_rgb888_result rgb888_result; + struct convert_to_bgr888_result bgr888_result; struct convert_to_argb8888_result argb8888_result; struct convert_to_xrgb2101010_result xrgb2101010_result; struct convert_to_argb2101010_result argb2101010_result; @@ -151,6 +157,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0x00, 0x00, 0xFF }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFF, 0x00, 0x00 }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0xFFFF0000 }, @@ -217,6 +227,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0x00, 0x00, 0xFF }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFF, 0x00, 0x00 }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0xFFFF0000 }, @@ -330,6 +344,15 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { + 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, + 0xFF, 0x00, 0x00, 0x00, 0xFF, 0x00, + 0x00, 0x00, 0xFF, 0xFF, 0x00, 0xFF, + 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, + }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { @@ -468,6 +491,17 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, + .bgr888_result = { + .dst_pitch = 15, + .expected = { + 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, 0xA8, 0xF3, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xA8, 0x03, 0x03, 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, .argb8888_result = { .dst_pitch = 20, .expected = { @@ -914,6 +948,52 @@ static void drm_test_fb_xrgb8888_to_rgb888(struct kunit *test) KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); } +static void drm_test_fb_xrgb8888_to_bgr888(struct kunit *test) +{ + const struct convert_xrgb8888_case *params = test->param_value; + const struct convert_to_bgr888_result *result = ¶ms->bgr888_result; + size_t dst_size; + u8 *buf = NULL; + __le32 *xrgb8888 = NULL; + struct iosys_map dst, src; + + struct drm_framebuffer fb = { + .format = drm_format_info(DRM_FORMAT_XRGB8888), + .pitches = { params->pitch, 0, 0 }, + }; + + dst_size = conversion_buf_size(DRM_FORMAT_BGR888, result->dst_pitch, + ¶ms->clip, 0); + KUNIT_ASSERT_GT(test, dst_size, 0); + + buf = kunit_kzalloc(test, dst_size, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, buf); + iosys_map_set_vaddr(&dst, buf); + + xrgb8888 = cpubuf_to_le32(test, params->xrgb8888, TEST_BUF_SIZE); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, xrgb8888); + iosys_map_set_vaddr(&src, xrgb8888); + + /* + * BGR888 expected results are already in little-endian + * order, so there's no need to convert the test output. + */ + drm_fb_xrgb8888_to_bgr888(&dst, &result->dst_pitch, &src, &fb, ¶ms->clip, + &fmtcnv_state); + KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); + + buf = dst.vaddr; /* restore original value of buf */ + memset(buf, 0, dst_size); + + int blit_result = 0; + + blit_result = drm_fb_blit(&dst, &result->dst_pitch, DRM_FORMAT_BGR888, &src, &fb, ¶ms->clip, + &fmtcnv_state); + + KUNIT_EXPECT_FALSE(test, blit_result); + KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); +} + static void drm_test_fb_xrgb8888_to_argb8888(struct kunit *test) { const struct convert_xrgb8888_case *params = test->param_value; @@ -1851,6 +1931,7 @@ static struct kunit_case drm_format_helper_test_cases[] = { KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb1555, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgba5551, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgb888, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_bgr888, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb8888, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_xrgb2101010, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb2101010, convert_xrgb8888_gen_params), diff --git a/drivers/gpu/drm/tiny/Kconfig b/drivers/gpu/drm/tiny/Kconfig index 94cbdb1337c0..54c84c9801c1 100644 --- a/drivers/gpu/drm/tiny/Kconfig +++ b/drivers/gpu/drm/tiny/Kconfig @@ -1,5 +1,17 @@ # SPDX-License-Identifier: GPL-2.0-only +config DRM_APPLETBDRM + tristate "DRM support for Apple Touch Bars" + depends on DRM && USB && MMU + select DRM_GEM_SHMEM_HELPER + select DRM_KMS_HELPER + help + Say Y here if you want support for the display of Touch Bars on x86 + MacBook Pros. + + To compile this driver as a module, choose M here: the + module will be called appletbdrm. + config DRM_ARCPGU tristate "ARC PGU" depends on DRM && OF diff --git a/drivers/gpu/drm/tiny/Makefile b/drivers/gpu/drm/tiny/Makefile index 60816d2eb4ff..0a3a7837a58b 100644 --- a/drivers/gpu/drm/tiny/Makefile +++ b/drivers/gpu/drm/tiny/Makefile @@ -1,5 +1,6 @@ # SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_DRM_APPLETBDRM) += appletbdrm.o obj-$(CONFIG_DRM_ARCPGU) += arcpgu.o obj-$(CONFIG_DRM_BOCHS) += bochs.o obj-$(CONFIG_DRM_CIRRUS_QEMU) += cirrus-qemu.o diff --git a/drivers/gpu/drm/tiny/appletbdrm.c b/drivers/gpu/drm/tiny/appletbdrm.c new file mode 100644 index 000000000000..4370ba22dd88 --- /dev/null +++ b/drivers/gpu/drm/tiny/appletbdrm.c @@ -0,0 +1,840 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar DRM Driver + * + * Copyright (c) 2023 Kerem Karabay + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define APPLETBDRM_PIXEL_FORMAT cpu_to_le32(0x52474241) /* RGBA, the actual format is BGR888 */ +#define APPLETBDRM_BITS_PER_PIXEL 24 + +#define APPLETBDRM_MSG_CLEAR_DISPLAY cpu_to_le32(0x434c5244) /* CLRD */ +#define APPLETBDRM_MSG_GET_INFORMATION cpu_to_le32(0x47494e46) /* GINF */ +#define APPLETBDRM_MSG_UPDATE_COMPLETE cpu_to_le32(0x5544434c) /* UDCL */ +#define APPLETBDRM_MSG_SIGNAL_READINESS cpu_to_le32(0x52454459) /* REDY */ + +#define APPLETBDRM_BULK_MSG_TIMEOUT 1000 + +#define drm_to_adev(_drm) container_of(_drm, struct appletbdrm_device, drm) +#define adev_to_udev(adev) interface_to_usbdev(to_usb_interface(adev->dmadev)) + +struct appletbdrm_msg_request_header { + __le16 unk_00; + __le16 unk_02; + __le32 unk_04; + __le32 unk_08; + __le32 size; +} __packed; + +struct appletbdrm_msg_response_header { + u8 unk_00[16]; + __le32 msg; +} __packed; + +struct appletbdrm_msg_simple_request { + struct appletbdrm_msg_request_header header; + __le32 msg; + u8 unk_14[8]; + __le32 size; +} __packed; + +struct appletbdrm_msg_information { + struct appletbdrm_msg_response_header header; + u8 unk_14[12]; + __le32 width; + __le32 height; + u8 bits_per_pixel; + __le32 bytes_per_row; + __le32 orientation; + __le32 bitmap_info; + __le32 pixel_format; + __le32 width_inches; /* floating point */ + __le32 height_inches; /* floating point */ +} __packed; + +struct appletbdrm_frame { + __le16 begin_x; + __le16 begin_y; + __le16 width; + __le16 height; + __le32 buf_size; + u8 buf[]; +} __packed; + +struct appletbdrm_fb_request_footer { + u8 unk_00[12]; + __le32 unk_0c; + u8 unk_10[12]; + __le32 unk_1c; + __le64 timestamp; + u8 unk_28[12]; + __le32 unk_34; + u8 unk_38[20]; + __le32 unk_4c; +} __packed; + +struct appletbdrm_fb_request { + struct appletbdrm_msg_request_header header; + __le16 unk_10; + u8 msg_id; + u8 unk_13[29]; + /* + * Contents of `data`: + * - struct appletbdrm_frame frames[]; + * - struct appletbdrm_fb_request_footer footer; + * - padding to make the total size a multiple of 16 + */ + u8 data[]; +} __packed; + +struct appletbdrm_fb_request_response { + struct appletbdrm_msg_response_header header; + u8 unk_14[12]; + __le64 timestamp; +} __packed; + +struct appletbdrm_device { + struct device *dmadev; + + unsigned int in_ep; + unsigned int out_ep; + + unsigned int width; + unsigned int height; + + struct drm_device drm; + struct drm_display_mode mode; + struct drm_connector connector; + struct drm_plane primary_plane; + struct drm_crtc crtc; + struct drm_encoder encoder; +}; + +struct appletbdrm_plane_state { + struct drm_shadow_plane_state base; + struct appletbdrm_fb_request *request; + struct appletbdrm_fb_request_response *response; + size_t request_size; + size_t frames_size; +}; + +static inline struct appletbdrm_plane_state *to_appletbdrm_plane_state(struct drm_plane_state *state) +{ + return container_of(state, struct appletbdrm_plane_state, base.base); +} + +static int appletbdrm_send_request(struct appletbdrm_device *adev, + struct appletbdrm_msg_request_header *request, size_t size) +{ + struct usb_device *udev = adev_to_udev(adev); + struct drm_device *drm = &adev->drm; + int ret, actual_size; + + ret = usb_bulk_msg(udev, usb_sndbulkpipe(udev, adev->out_ep), + request, size, &actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); + if (ret) { + drm_err(drm, "Failed to send message (%d)\n", ret); + return ret; + } + + if (actual_size != size) { + drm_err(drm, "Actual size (%d) doesn't match expected size (%zu)\n", + actual_size, size); + return -EIO; + } + + return 0; +} + +static int appletbdrm_read_response(struct appletbdrm_device *adev, + struct appletbdrm_msg_response_header *response, + size_t size, __le32 expected_response) +{ + struct usb_device *udev = adev_to_udev(adev); + struct drm_device *drm = &adev->drm; + int ret, actual_size; + bool readiness_signal_received = false; + +retry: + ret = usb_bulk_msg(udev, usb_rcvbulkpipe(udev, adev->in_ep), + response, size, &actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); + if (ret) { + drm_err(drm, "Failed to read response (%d)\n", ret); + return ret; + } + + /* + * The device responds to the first request sent in a particular + * timeframe after the USB device configuration is set with a readiness + * signal, in which case the response should be read again + */ + if (response->msg == APPLETBDRM_MSG_SIGNAL_READINESS) { + if (!readiness_signal_received) { + readiness_signal_received = true; + goto retry; + } + + drm_err(drm, "Encountered unexpected readiness signal\n"); + return -EINTR; + } + + if (actual_size != size) { + drm_err(drm, "Actual size (%d) doesn't match expected size (%zu)\n", + actual_size, size); + return -EBADMSG; + } + + if (response->msg != expected_response) { + drm_err(drm, "Unexpected response from device (expected %p4cc found %p4cc)\n", + &expected_response, &response->msg); + return -EIO; + } + + return 0; +} + +static int appletbdrm_send_msg(struct appletbdrm_device *adev, __le32 msg) +{ + struct appletbdrm_msg_simple_request *request; + int ret; + + request = kzalloc(sizeof(*request), GFP_KERNEL); + if (!request) + return -ENOMEM; + + request->header.unk_00 = cpu_to_le16(2); + request->header.unk_02 = cpu_to_le16(0x1512); + request->header.size = cpu_to_le32(sizeof(*request) - sizeof(request->header)); + request->msg = msg; + request->size = request->header.size; + + ret = appletbdrm_send_request(adev, &request->header, sizeof(*request)); + + kfree(request); + + return ret; +} + +static int appletbdrm_clear_display(struct appletbdrm_device *adev) +{ + return appletbdrm_send_msg(adev, APPLETBDRM_MSG_CLEAR_DISPLAY); +} + +static int appletbdrm_signal_readiness(struct appletbdrm_device *adev) +{ + return appletbdrm_send_msg(adev, APPLETBDRM_MSG_SIGNAL_READINESS); +} + +static int appletbdrm_get_information(struct appletbdrm_device *adev) +{ + struct appletbdrm_msg_information *info; + struct drm_device *drm = &adev->drm; + u8 bits_per_pixel; + __le32 pixel_format; + int ret; + + info = kzalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return -ENOMEM; + + ret = appletbdrm_send_msg(adev, APPLETBDRM_MSG_GET_INFORMATION); + if (ret) + return ret; + + ret = appletbdrm_read_response(adev, &info->header, sizeof(*info), + APPLETBDRM_MSG_GET_INFORMATION); + if (ret) + goto free_info; + + bits_per_pixel = info->bits_per_pixel; + pixel_format = get_unaligned(&info->pixel_format); + + adev->width = get_unaligned_le32(&info->width); + adev->height = get_unaligned_le32(&info->height); + + if (bits_per_pixel != APPLETBDRM_BITS_PER_PIXEL) { + drm_err(drm, "Encountered unexpected bits per pixel value (%d)\n", bits_per_pixel); + ret = -EINVAL; + goto free_info; + } + + if (pixel_format != APPLETBDRM_PIXEL_FORMAT) { + drm_err(drm, "Encountered unknown pixel format (%p4cc)\n", &pixel_format); + ret = -EINVAL; + goto free_info; + } + +free_info: + kfree(info); + + return ret; +} + +static u32 rect_size(struct drm_rect *rect) +{ + return drm_rect_width(rect) * drm_rect_height(rect) * + (BITS_TO_BYTES(APPLETBDRM_BITS_PER_PIXEL)); +} + +static int appletbdrm_connector_helper_get_modes(struct drm_connector *connector) +{ + struct appletbdrm_device *adev = drm_to_adev(connector->dev); + + return drm_connector_helper_get_modes_fixed(connector, &adev->mode); +} + +static const u32 appletbdrm_primary_plane_formats[] = { + DRM_FORMAT_BGR888, + DRM_FORMAT_XRGB8888, /* emulated */ +}; + +static int appletbdrm_primary_plane_helper_atomic_check(struct drm_plane *plane, + struct drm_atomic_state *state) +{ + struct drm_plane_state *new_plane_state = drm_atomic_get_new_plane_state(state, plane); + struct drm_plane_state *old_plane_state = drm_atomic_get_old_plane_state(state, plane); + struct drm_crtc *new_crtc = new_plane_state->crtc; + struct drm_crtc_state *new_crtc_state = NULL; + struct appletbdrm_plane_state *appletbdrm_state = to_appletbdrm_plane_state(new_plane_state); + struct drm_atomic_helper_damage_iter iter; + struct drm_rect damage; + size_t frames_size = 0; + size_t request_size; + int ret; + + if (new_crtc) + new_crtc_state = drm_atomic_get_new_crtc_state(state, new_crtc); + + ret = drm_atomic_helper_check_plane_state(new_plane_state, new_crtc_state, + DRM_PLANE_NO_SCALING, + DRM_PLANE_NO_SCALING, + false, false); + if (ret) + return ret; + else if (!new_plane_state->visible) + return 0; + + drm_atomic_helper_damage_iter_init(&iter, old_plane_state, new_plane_state); + drm_atomic_for_each_plane_damage(&iter, &damage) { + frames_size += struct_size((struct appletbdrm_frame *)0, buf, rect_size(&damage)); + } + + if (!frames_size) + return 0; + + request_size = ALIGN(sizeof(struct appletbdrm_fb_request) + + frames_size + + sizeof(struct appletbdrm_fb_request_footer), 16); + + appletbdrm_state->request = kzalloc(request_size, GFP_KERNEL); + + if (!appletbdrm_state->request) + return -ENOMEM; + + appletbdrm_state->response = kzalloc(sizeof(*appletbdrm_state->response), GFP_KERNEL); + + if (!appletbdrm_state->response) + return -ENOMEM; + + appletbdrm_state->request_size = request_size; + appletbdrm_state->frames_size = frames_size; + + return 0; +} + +static int appletbdrm_flush_damage(struct appletbdrm_device *adev, + struct drm_plane_state *old_state, + struct drm_plane_state *state) +{ + struct appletbdrm_plane_state *appletbdrm_state = to_appletbdrm_plane_state(state); + struct drm_shadow_plane_state *shadow_plane_state = to_drm_shadow_plane_state(state); + struct appletbdrm_fb_request_response *response = appletbdrm_state->response; + struct appletbdrm_fb_request_footer *footer; + struct drm_atomic_helper_damage_iter iter; + struct drm_framebuffer *fb = state->fb; + struct appletbdrm_fb_request *request = appletbdrm_state->request; + struct drm_device *drm = &adev->drm; + struct appletbdrm_frame *frame; + u64 timestamp = ktime_get_ns(); + struct drm_rect damage; + size_t frames_size = appletbdrm_state->frames_size; + size_t request_size = appletbdrm_state->request_size; + int ret; + + if (!frames_size) + return 0; + + ret = drm_gem_fb_begin_cpu_access(fb, DMA_FROM_DEVICE); + if (ret) { + drm_err(drm, "Failed to start CPU framebuffer access (%d)\n", ret); + goto end_fb_cpu_access; + } + + request->header.unk_00 = cpu_to_le16(2); + request->header.unk_02 = cpu_to_le16(0x12); + request->header.unk_04 = cpu_to_le32(9); + request->header.size = cpu_to_le32(request_size - sizeof(request->header)); + request->unk_10 = cpu_to_le16(1); + request->msg_id = timestamp; + + frame = (struct appletbdrm_frame *)request->data; + + drm_atomic_helper_damage_iter_init(&iter, old_state, state); + drm_atomic_for_each_plane_damage(&iter, &damage) { + struct drm_rect dst_clip = state->dst; + struct iosys_map dst = IOSYS_MAP_INIT_VADDR(frame->buf); + u32 buf_size = rect_size(&damage); + + if (!drm_rect_intersect(&dst_clip, &damage)) + continue; + + /* + * The coordinates need to be translated to the coordinate + * system the device expects, see the comment in + * appletbdrm_setup_mode_config + */ + frame->begin_x = cpu_to_le16(damage.y1); + frame->begin_y = cpu_to_le16(adev->height - damage.x2); + frame->width = cpu_to_le16(drm_rect_height(&damage)); + frame->height = cpu_to_le16(drm_rect_width(&damage)); + frame->buf_size = cpu_to_le32(buf_size); + + switch (fb->format->format) { + case DRM_FORMAT_XRGB8888: + drm_fb_xrgb8888_to_bgr888(&dst, NULL, &shadow_plane_state->data[0], fb, &damage, &shadow_plane_state->fmtcnv_state); + break; + default: + drm_fb_memcpy(&dst, NULL, &shadow_plane_state->data[0], fb, &damage); + break; + } + + frame = (void *)frame + struct_size(frame, buf, buf_size); + } + + footer = (struct appletbdrm_fb_request_footer *)&request->data[frames_size]; + + footer->unk_0c = cpu_to_le32(0xfffe); + footer->unk_1c = cpu_to_le32(0x80001); + footer->unk_34 = cpu_to_le32(0x80002); + footer->unk_4c = cpu_to_le32(0xffff); + footer->timestamp = cpu_to_le64(timestamp); + + ret = appletbdrm_send_request(adev, &request->header, request_size); + if (ret) + goto end_fb_cpu_access; + + ret = appletbdrm_read_response(adev, &response->header, sizeof(*response), + APPLETBDRM_MSG_UPDATE_COMPLETE); + if (ret) + goto end_fb_cpu_access; + + if (response->timestamp != footer->timestamp) { + drm_err(drm, "Response timestamp (%llu) doesn't match request timestamp (%llu)\n", + le64_to_cpu(response->timestamp), timestamp); + goto end_fb_cpu_access; + } + +end_fb_cpu_access: + drm_gem_fb_end_cpu_access(fb, DMA_FROM_DEVICE); + + return ret; +} + +static void appletbdrm_primary_plane_helper_atomic_update(struct drm_plane *plane, + struct drm_atomic_state *old_state) +{ + struct appletbdrm_device *adev = drm_to_adev(plane->dev); + struct drm_device *drm = plane->dev; + struct drm_plane_state *plane_state = plane->state; + struct drm_plane_state *old_plane_state = drm_atomic_get_old_plane_state(old_state, plane); + int idx; + + if (!drm_dev_enter(drm, &idx)) + return; + + appletbdrm_flush_damage(adev, old_plane_state, plane_state); + + drm_dev_exit(idx); +} + +static void appletbdrm_primary_plane_helper_atomic_disable(struct drm_plane *plane, + struct drm_atomic_state *state) +{ + struct drm_device *dev = plane->dev; + struct appletbdrm_device *adev = drm_to_adev(dev); + int idx; + + if (!drm_dev_enter(dev, &idx)) + return; + + appletbdrm_clear_display(adev); + + drm_dev_exit(idx); +} + +static void appletbdrm_primary_plane_reset(struct drm_plane *plane) +{ + struct appletbdrm_plane_state *appletbdrm_state; + + WARN_ON(plane->state); + + appletbdrm_state = kzalloc(sizeof(*appletbdrm_state), GFP_KERNEL); + if (!appletbdrm_state) + return; + + __drm_gem_reset_shadow_plane(plane, &appletbdrm_state->base); +} + +static struct drm_plane_state *appletbdrm_primary_plane_duplicate_state(struct drm_plane *plane) +{ + struct drm_shadow_plane_state *new_shadow_plane_state; + struct appletbdrm_plane_state *appletbdrm_state; + + if (WARN_ON(!plane->state)) + return NULL; + + appletbdrm_state = kzalloc(sizeof(*appletbdrm_state), GFP_KERNEL); + if (!appletbdrm_state) + return NULL; + + /* Request and response are not duplicated and are allocated in .atomic_check */ + appletbdrm_state->request = NULL; + appletbdrm_state->response = NULL; + + appletbdrm_state->request_size = 0; + appletbdrm_state->frames_size = 0; + + new_shadow_plane_state = &appletbdrm_state->base; + + __drm_gem_duplicate_shadow_plane_state(plane, new_shadow_plane_state); + + return &new_shadow_plane_state->base; +} + +static void appletbdrm_primary_plane_destroy_state(struct drm_plane *plane, + struct drm_plane_state *state) +{ + struct appletbdrm_plane_state *appletbdrm_state = to_appletbdrm_plane_state(state); + + kfree(appletbdrm_state->request); + kfree(appletbdrm_state->response); + + __drm_gem_destroy_shadow_plane_state(&appletbdrm_state->base); + + kfree(appletbdrm_state); +} + +static const struct drm_plane_helper_funcs appletbdrm_primary_plane_helper_funcs = { + DRM_GEM_SHADOW_PLANE_HELPER_FUNCS, + .atomic_check = appletbdrm_primary_plane_helper_atomic_check, + .atomic_update = appletbdrm_primary_plane_helper_atomic_update, + .atomic_disable = appletbdrm_primary_plane_helper_atomic_disable, +}; + +static const struct drm_plane_funcs appletbdrm_primary_plane_funcs = { + .update_plane = drm_atomic_helper_update_plane, + .disable_plane = drm_atomic_helper_disable_plane, + .reset = appletbdrm_primary_plane_reset, + .atomic_duplicate_state = appletbdrm_primary_plane_duplicate_state, + .atomic_destroy_state = appletbdrm_primary_plane_destroy_state, + .destroy = drm_plane_cleanup, +}; + +static enum drm_mode_status appletbdrm_crtc_helper_mode_valid(struct drm_crtc *crtc, + const struct drm_display_mode *mode) +{ + struct appletbdrm_device *adev = drm_to_adev(crtc->dev); + + return drm_crtc_helper_mode_valid_fixed(crtc, mode, &adev->mode); +} + +static const struct drm_mode_config_funcs appletbdrm_mode_config_funcs = { + .fb_create = drm_gem_fb_create_with_dirty, + .atomic_check = drm_atomic_helper_check, + .atomic_commit = drm_atomic_helper_commit, +}; + +static const struct drm_connector_funcs appletbdrm_connector_funcs = { + .reset = drm_atomic_helper_connector_reset, + .destroy = drm_connector_cleanup, + .fill_modes = drm_helper_probe_single_connector_modes, + .atomic_destroy_state = drm_atomic_helper_connector_destroy_state, + .atomic_duplicate_state = drm_atomic_helper_connector_duplicate_state, +}; + +static const struct drm_connector_helper_funcs appletbdrm_connector_helper_funcs = { + .get_modes = appletbdrm_connector_helper_get_modes, +}; + +static const struct drm_crtc_helper_funcs appletbdrm_crtc_helper_funcs = { + .mode_valid = appletbdrm_crtc_helper_mode_valid, +}; + +static const struct drm_crtc_funcs appletbdrm_crtc_funcs = { + .reset = drm_atomic_helper_crtc_reset, + .destroy = drm_crtc_cleanup, + .set_config = drm_atomic_helper_set_config, + .page_flip = drm_atomic_helper_page_flip, + .atomic_duplicate_state = drm_atomic_helper_crtc_duplicate_state, + .atomic_destroy_state = drm_atomic_helper_crtc_destroy_state, +}; + +static const struct drm_encoder_funcs appletbdrm_encoder_funcs = { + .destroy = drm_encoder_cleanup, +}; + +static struct drm_gem_object *appletbdrm_driver_gem_prime_import(struct drm_device *dev, + struct dma_buf *dma_buf) +{ + struct appletbdrm_device *adev = drm_to_adev(dev); + + if (!adev->dmadev) + return ERR_PTR(-ENODEV); + + return drm_gem_prime_import_dev(dev, dma_buf, adev->dmadev); +} + +DEFINE_DRM_GEM_FOPS(appletbdrm_drm_fops); + +static const struct drm_driver appletbdrm_drm_driver = { + DRM_GEM_SHMEM_DRIVER_OPS, + .gem_prime_import = appletbdrm_driver_gem_prime_import, + .name = "appletbdrm", + .desc = "Apple Touch Bar DRM Driver", + .major = 1, + .minor = 0, + .driver_features = DRIVER_MODESET | DRIVER_GEM | DRIVER_ATOMIC, + .fops = &appletbdrm_drm_fops, +}; + +static int appletbdrm_setup_mode_config(struct appletbdrm_device *adev) +{ + struct drm_connector *connector = &adev->connector; + struct drm_plane *primary_plane; + struct drm_crtc *crtc; + struct drm_encoder *encoder; + struct drm_device *drm = &adev->drm; + int ret; + + ret = drmm_mode_config_init(drm); + if (ret) { + drm_err(drm, "Failed to initialize mode configuration\n"); + return ret; + } + + primary_plane = &adev->primary_plane; + ret = drm_universal_plane_init(drm, primary_plane, 0, + &appletbdrm_primary_plane_funcs, + appletbdrm_primary_plane_formats, + ARRAY_SIZE(appletbdrm_primary_plane_formats), + NULL, + DRM_PLANE_TYPE_PRIMARY, NULL); + if (ret) { + drm_err(drm, "Failed to initialize universal plane object\n"); + return ret; + } + + drm_plane_helper_add(primary_plane, &appletbdrm_primary_plane_helper_funcs); + drm_plane_enable_fb_damage_clips(primary_plane); + + crtc = &adev->crtc; + ret = drm_crtc_init_with_planes(drm, crtc, primary_plane, NULL, + &appletbdrm_crtc_funcs, NULL); + if (ret) { + drm_err(drm, "Failed to initialize CRTC object\n"); + return ret; + } + + drm_crtc_helper_add(crtc, &appletbdrm_crtc_helper_funcs); + + encoder = &adev->encoder; + ret = drm_encoder_init(drm, encoder, &appletbdrm_encoder_funcs, + DRM_MODE_ENCODER_DAC, NULL); + if (ret) { + drm_err(drm, "Failed to initialize encoder\n"); + return ret; + } + + encoder->possible_crtcs = drm_crtc_mask(crtc); + + /* + * The coordinate system used by the device is different from the + * coordinate system of the framebuffer in that the x and y axes are + * swapped, and that the y axis is inverted; so what the device reports + * as the height is actually the width of the framebuffer and vice + * versa. + */ + drm->mode_config.max_width = max(adev->height, DRM_SHADOW_PLANE_MAX_WIDTH); + drm->mode_config.max_height = max(adev->width, DRM_SHADOW_PLANE_MAX_HEIGHT); + drm->mode_config.preferred_depth = APPLETBDRM_BITS_PER_PIXEL; + drm->mode_config.funcs = &appletbdrm_mode_config_funcs; + + adev->mode = (struct drm_display_mode) { + DRM_MODE_INIT(60, adev->height, adev->width, + DRM_MODE_RES_MM(adev->height, 218), + DRM_MODE_RES_MM(adev->width, 218)) + }; + + ret = drm_connector_init(drm, connector, + &appletbdrm_connector_funcs, DRM_MODE_CONNECTOR_USB); + if (ret) { + drm_err(drm, "Failed to initialize connector\n"); + return ret; + } + + drm_connector_helper_add(connector, &appletbdrm_connector_helper_funcs); + + ret = drm_connector_set_panel_orientation(connector, + DRM_MODE_PANEL_ORIENTATION_RIGHT_UP); + if (ret) { + drm_err(drm, "Failed to set panel orientation\n"); + return ret; + } + + connector->display_info.non_desktop = true; + ret = drm_object_property_set_value(&connector->base, + drm->mode_config.non_desktop_property, true); + if (ret) { + drm_err(drm, "Failed to set non-desktop property\n"); + return ret; + } + + ret = drm_connector_attach_encoder(connector, encoder); + + if (ret) { + drm_err(drm, "Failed to initialize simple display pipe\n"); + return ret; + } + + drm_mode_config_reset(drm); + + return 0; +} + +static int appletbdrm_probe(struct usb_interface *intf, + const struct usb_device_id *id) +{ + struct usb_endpoint_descriptor *bulk_in, *bulk_out; + struct device *dev = &intf->dev; + struct appletbdrm_device *adev; + struct drm_device *drm = NULL; + int ret; + + ret = usb_find_common_endpoints(intf->cur_altsetting, &bulk_in, &bulk_out, NULL, NULL); + if (ret) { + drm_err(drm, "appletbdrm: Failed to find bulk endpoints\n"); + return ret; + } + + adev = devm_drm_dev_alloc(dev, &appletbdrm_drm_driver, struct appletbdrm_device, drm); + if (IS_ERR(adev)) + return PTR_ERR(adev); + + adev->in_ep = bulk_in->bEndpointAddress; + adev->out_ep = bulk_out->bEndpointAddress; + adev->dmadev = dev; + + drm = &adev->drm; + + usb_set_intfdata(intf, adev); + + ret = appletbdrm_get_information(adev); + if (ret) { + drm_err(drm, "Failed to get display information\n"); + return ret; + } + + ret = appletbdrm_signal_readiness(adev); + if (ret) { + drm_err(drm, "Failed to signal readiness\n"); + return ret; + } + + ret = appletbdrm_setup_mode_config(adev); + if (ret) { + drm_err(drm, "Failed to setup mode config\n"); + return ret; + } + + ret = drm_dev_register(drm, 0); + if (ret) { + drm_err(drm, "Failed to register DRM device\n"); + return ret; + } + + ret = appletbdrm_clear_display(adev); + if (ret) { + drm_err(drm, "Failed to clear display\n"); + return ret; + } + + return 0; +} + +static void appletbdrm_disconnect(struct usb_interface *intf) +{ + struct appletbdrm_device *adev = usb_get_intfdata(intf); + struct drm_device *drm = &adev->drm; + + drm_dev_unplug(drm); + drm_atomic_helper_shutdown(drm); +} + +static void appletbdrm_shutdown(struct usb_interface *intf) +{ + struct appletbdrm_device *adev = usb_get_intfdata(intf); + + /* + * The framebuffer needs to be cleared on shutdown since its content + * persists across boots + */ + drm_atomic_helper_shutdown(&adev->drm); +} + +static const struct usb_device_id appletbdrm_usb_id_table[] = { + { USB_DEVICE_INTERFACE_CLASS(0x05ac, 0x8302, USB_CLASS_AUDIO_VIDEO) }, + {} +}; +MODULE_DEVICE_TABLE(usb, appletbdrm_usb_id_table); + +static struct usb_driver appletbdrm_usb_driver = { + .name = "appletbdrm", + .probe = appletbdrm_probe, + .disconnect = appletbdrm_disconnect, + .shutdown = appletbdrm_shutdown, + .id_table = appletbdrm_usb_id_table, +}; +module_usb_driver(appletbdrm_usb_driver); + +MODULE_AUTHOR("Kerem Karabay "); +MODULE_DESCRIPTION("Apple Touch Bar DRM Driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/gpu/vga/vga_switcheroo.c b/drivers/gpu/vga/vga_switcheroo.c index 18f2c92beff8..3de1bca45ed2 100644 --- a/drivers/gpu/vga/vga_switcheroo.c +++ b/drivers/gpu/vga/vga_switcheroo.c @@ -438,12 +438,7 @@ find_active_client(struct list_head *head) bool vga_switcheroo_client_probe_defer(struct pci_dev *pdev) { if ((pdev->class >> 16) == PCI_BASE_CLASS_DISPLAY) { - /* - * apple-gmux is needed on pre-retina MacBook Pro - * to probe the panel if pdev is the inactive GPU. - */ - if (apple_gmux_present() && pdev != vga_default_device() && - !vgasr_priv.handler_flags) + if (apple_gmux_present() && !vgasr_priv.handler_flags) return true; } diff --git a/drivers/hid/Kconfig b/drivers/hid/Kconfig index d979b18f7f5b..d9364bbfb76c 100644 --- a/drivers/hid/Kconfig +++ b/drivers/hid/Kconfig @@ -129,7 +129,7 @@ config HID_APPLE tristate "Apple {i,Power,Mac}Books" depends on LEDS_CLASS depends on NEW_LEDS - default !EXPERT + default !EXPERT || SPI_HID_APPLE help Support for some Apple devices which less or more break HID specification. @@ -148,6 +148,31 @@ config HID_APPLEIR Say Y here if you want support for Apple infrared remote control. +config HID_APPLETB_BL + tristate "Apple Touch Bar Backlight" + depends on BACKLIGHT_CLASS_DEVICE + help + Say Y here if you want support for the backlight of Touch Bars on x86 + MacBook Pros. + + To compile this driver as a module, choose M here: the + module will be called hid-appletb-bl. + +config HID_APPLETB_KBD + tristate "Apple Touch Bar Keyboard Mode" + depends on USB_HID + depends on BACKLIGHT_CLASS_DEVICE + depends on INPUT + select INPUT_SPARSEKMAP + select HID_APPLETB_BL + help + Say Y here if you want support for the keyboard mode (escape, + function, media and brightness keys) of Touch Bars on x86 MacBook + Pros. + + To compile this driver as a module, choose M here: the + module will be called hid-appletb-kbd. + config HID_ASUS tristate "Asus" depends on USB_HID @@ -698,11 +723,13 @@ config LOGIWHEELS_FF config HID_MAGICMOUSE tristate "Apple Magic Mouse/Trackpad multi-touch support" + default SPI_HID_APPLE help Support for the Apple Magic Mouse/Trackpad multi-touch. Say Y here if you want support for the multi-touch features of the - Apple Wireless "Magic" Mouse and the Apple Wireless "Magic" Trackpad. + Apple Wireless "Magic" Mouse, the Apple Wireless "Magic" Trackpad and + force touch Trackpads in Macbooks starting from 2015. config HID_MALTRON tristate "Maltron L90 keyboard" @@ -752,6 +779,7 @@ config HID_MULTITOUCH Say Y here if you have one of the following devices: - 3M PCT touch screens - ActionStar dual touch panels + - Touch Bars on x86 MacBook Pros - Atmel panels - Cando dual touch panels - Chunghwa panels @@ -1416,4 +1444,8 @@ endif # HID source "drivers/hid/usbhid/Kconfig" +source "drivers/hid/spi-hid/Kconfig" + +source "drivers/hid/dockchannel-hid/Kconfig" + endif # HID_SUPPORT diff --git a/drivers/hid/Makefile b/drivers/hid/Makefile index 733ab7cc5813..1aa7713ae9f8 100644 --- a/drivers/hid/Makefile +++ b/drivers/hid/Makefile @@ -29,6 +29,8 @@ obj-$(CONFIG_HID_ALPS) += hid-alps.o obj-$(CONFIG_HID_ACRUX) += hid-axff.o obj-$(CONFIG_HID_APPLE) += hid-apple.o obj-$(CONFIG_HID_APPLEIR) += hid-appleir.o +obj-$(CONFIG_HID_APPLETB_BL) += hid-appletb-bl.o +obj-$(CONFIG_HID_APPLETB_KBD) += hid-appletb-kbd.o obj-$(CONFIG_HID_CREATIVE_SB0540) += hid-creative-sb0540.o obj-$(CONFIG_HID_ASUS) += hid-asus.o obj-$(CONFIG_HID_ASUS_ALLY) += hid-asus-ally.o @@ -174,3 +176,7 @@ obj-$(CONFIG_AMD_SFH_HID) += amd-sfh-hid/ obj-$(CONFIG_SURFACE_HID_CORE) += surface-hid/ obj-$(CONFIG_INTEL_THC_HID) += intel-thc-hid/ + +obj-$(CONFIG_SPI_HID_APPLE_CORE) += spi-hid/ + +obj-$(CONFIG_HID_DOCKCHANNEL) += dockchannel-hid/ diff --git a/drivers/hid/dockchannel-hid/Kconfig b/drivers/hid/dockchannel-hid/Kconfig new file mode 100644 index 000000000000..8a81d551a83d --- /dev/null +++ b/drivers/hid/dockchannel-hid/Kconfig @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0-only OR MIT +menu "DockChannel HID support" + depends on APPLE_DOCKCHANNEL + +config HID_DOCKCHANNEL + tristate "HID over DockChannel transport layer for Apple Silicon SoCs" + default ARCH_APPLE + depends on APPLE_DOCKCHANNEL && INPUT && OF && HID + help + Say Y here if you use an M2 or later Apple Silicon based laptop. + The keyboard and touchpad are HID based devices connected via the + proprietary DockChannel interface. + +endmenu diff --git a/drivers/hid/dockchannel-hid/Makefile b/drivers/hid/dockchannel-hid/Makefile new file mode 100644 index 000000000000..7dba766b047f --- /dev/null +++ b/drivers/hid/dockchannel-hid/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only OR MIT +# +# Makefile for DockChannel HID transport drivers +# + +obj-$(CONFIG_HID_DOCKCHANNEL) += dockchannel-hid.o diff --git a/drivers/hid/dockchannel-hid/dockchannel-hid.c b/drivers/hid/dockchannel-hid/dockchannel-hid.c new file mode 100644 index 000000000000..a712a724ded3 --- /dev/null +++ b/drivers/hid/dockchannel-hid/dockchannel-hid.c @@ -0,0 +1,1213 @@ +/* + * SPDX-License-Identifier: GPL-2.0 OR MIT + * + * Apple DockChannel HID transport driver + * + * Copyright The Asahi Linux Contributors + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../hid-ids.h" + +#define COMMAND_TIMEOUT_MS 1000 +#define START_TIMEOUT_MS 2000 + +#define MAX_INTERFACES 16 + +/* Data + checksum */ +#define MAX_PKT_SIZE (0xffff + 4) + +#define DCHID_CHANNEL_CMD 0x11 +#define DCHID_CHANNEL_REPORT 0x12 + +struct dchid_hdr { + u8 hdr_len; + u8 channel; + u16 length; + u8 seq; + u8 iface; + u16 pad; +} __packed; + +#define IFACE_COMM 0 + +#define FLAGS_GROUP GENMASK(7, 6) +#define FLAGS_REQ GENMASK(5, 0) + +#define REQ_SET_REPORT 0 +#define REQ_GET_REPORT 1 + +struct dchid_subhdr { + u8 flags; + u8 unk; + u16 length; + u32 retcode; +} __packed; + +#define EVENT_GPIO_CMD 0xa0 +#define EVENT_INIT 0xf0 +#define EVENT_READY 0xf1 + +struct dchid_init_hdr { + u8 type; + u8 unk1; + u8 unk2; + u8 iface; + char name[16]; + u8 more_packets; + u8 unkpad; +} __packed; + +#define INIT_HID_DESCRIPTOR 0 +#define INIT_GPIO_REQUEST 1 +#define INIT_TERMINATOR 2 +#define INIT_PRODUCT_NAME 7 + +#define CMD_RESET_INTERFACE 0x40 +#define CMD_SEND_FIRMWARE 0x95 +#define CMD_ENABLE_INTERFACE 0xb4 +#define CMD_ACK_GPIO_CMD 0xa1 + +struct dchid_init_block_hdr { + u16 type; + u16 length; +} __packed; + +#define MAX_GPIO_NAME 32 + +struct dchid_gpio_request { + u16 unk; + u16 id; + char name[MAX_GPIO_NAME]; +} __packed; + +struct dchid_gpio_cmd { + u8 type; + u8 iface; + u8 gpio; + u8 unk; + u8 cmd; +} __packed; + +struct dchid_gpio_ack { + u8 type; + u32 retcode; + u8 cmd[]; +} __packed; + +#define STM_REPORT_ID 0x10 +#define STM_REPORT_SERIAL 0x11 +#define STM_REPORT_KEYBTYPE 0x14 + +struct dchid_stm_id { + u8 unk; + u16 vendor_id; + u16 product_id; + u16 version_number; + u8 unk2; + u8 unk3; + u8 keyboard_type; + u8 serial_length; + /* Serial follows, but we grab it with a different report. */ +} __packed; + +#define FW_MAGIC 0x46444948 +#define FW_VER 1 + +struct fw_header { + u32 magic; + u32 version; + u32 hdr_length; + u32 data_length; + u32 iface_offset; +} __packed; + +struct dchid_work { + struct work_struct work; + struct dchid_iface *iface; + + struct dchid_hdr hdr; + u8 data[]; +}; + +struct dchid_iface { + struct dockchannel_hid *dchid; + struct hid_device *hid; + struct workqueue_struct *wq; + + bool creating; + struct work_struct create_work; + + int index; + const char *name; + const struct device_node *of_node; + + uint8_t tx_seq; + bool deferred; + bool starting; + bool open; + struct completion ready; + + void *hid_desc; + size_t hid_desc_len; + + struct gpio_desc *gpio; + char gpio_name[MAX_GPIO_NAME]; + int gpio_id; + + struct mutex out_mutex; + u32 out_flags; + int out_report; + u32 retcode; + void *resp_buf; + size_t resp_size; + struct completion out_complete; + + u32 keyboard_layout_id; +}; + +struct dockchannel_hid { + struct device *dev; + struct dockchannel *dc; + struct device_link *helper_link; + + bool id_ready; + struct dchid_stm_id device_id; + char serial[64]; + + struct dchid_iface *comm; + struct dchid_iface *ifaces[MAX_INTERFACES]; + + u8 pkt_buf[MAX_PKT_SIZE]; + + /* Workqueue to asynchronously create HID devices */ + struct workqueue_struct *new_iface_wq; +}; + +static ssize_t apple_layout_id_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct hid_device *hdev = to_hid_device(dev); + struct dchid_iface *iface = hdev->driver_data; + + return scnprintf(buf, PAGE_SIZE, "%d\n", iface->keyboard_layout_id); +} + +static DEVICE_ATTR_RO(apple_layout_id); + +static struct dchid_iface * +dchid_get_interface(struct dockchannel_hid *dchid, int index, const char *name) +{ + struct dchid_iface *iface; + + if (index >= MAX_INTERFACES) { + dev_err(dchid->dev, "Interface index %d out of range\n", index); + return NULL; + } + + if (dchid->ifaces[index]) + return dchid->ifaces[index]; + + iface = devm_kzalloc(dchid->dev, sizeof(struct dchid_iface), GFP_KERNEL); + if (!iface) + return NULL; + + iface->index = index; + iface->name = devm_kstrdup(dchid->dev, name, GFP_KERNEL); + iface->dchid = dchid; + iface->out_report= -1; + init_completion(&iface->out_complete); + init_completion(&iface->ready); + mutex_init(&iface->out_mutex); + iface->wq = alloc_ordered_workqueue("dchid-%s", WQ_MEM_RECLAIM, iface->name); + if (!iface->wq) + return NULL; + + /* Comm is not a HID subdevice */ + if (!strcmp(name, "comm")) { + dchid->ifaces[index] = iface; + return iface; + } + + iface->of_node = of_get_child_by_name(dchid->dev->of_node, name); + if (!iface->of_node) { + dev_warn(dchid->dev, "No OF node for subdevice %s, ignoring.", name); + return NULL; + } + + dchid->ifaces[index] = iface; + return iface; +} + +static u32 dchid_checksum(void *p, size_t length) +{ + u32 sum = 0; + + while (length >= 4) { + sum += get_unaligned_le32(p); + p += 4; + length -= 4; + } + + WARN_ON_ONCE(length); + return sum; +} + +static int dchid_send(struct dchid_iface *iface, u32 flags, void *msg, size_t size) +{ + u32 checksum = 0xffffffff; + size_t wsize = round_down(size, 4); + size_t tsize = size - wsize; + int ret; + struct { + struct dchid_hdr hdr; + struct dchid_subhdr sub; + } __packed h; + + memset(&h, 0, sizeof(h)); + h.hdr.hdr_len = sizeof(h.hdr); + h.hdr.channel = DCHID_CHANNEL_CMD; + h.hdr.length = round_up(size, 4) + sizeof(h.sub); + h.hdr.seq = iface->tx_seq; + h.hdr.iface = iface->index; + h.sub.flags = flags; + h.sub.length = size; + + ret = dockchannel_send(iface->dchid->dc, &h, sizeof(h)); + if (ret < 0) + return ret; + checksum -= dchid_checksum(&h, sizeof(h)); + + ret = dockchannel_send(iface->dchid->dc, msg, wsize); + if (ret < 0) + return ret; + checksum -= dchid_checksum(msg, wsize); + + if (tsize) { + u8 tail[4] = {0, 0, 0, 0}; + + memcpy(tail, msg + wsize, tsize); + ret = dockchannel_send(iface->dchid->dc, tail, sizeof(tail)); + if (ret < 0) + return ret; + checksum -= dchid_checksum(tail, sizeof(tail)); + } + + ret = dockchannel_send(iface->dchid->dc, &checksum, sizeof(checksum)); + if (ret < 0) + return ret; + + return 0; +} + +static int dchid_cmd(struct dchid_iface *iface, u32 type, u32 req, + void *data, size_t size, void *resp_buf, size_t resp_size) +{ + int ret; + int report_id = *(u8*)data; + + mutex_lock(&iface->out_mutex); + + WARN_ON(iface->out_report != -1); + iface->out_report = report_id; + iface->out_flags = FIELD_PREP(FLAGS_GROUP, type) | FIELD_PREP(FLAGS_REQ, req); + iface->resp_buf = resp_buf; + iface->resp_size = resp_size; + reinit_completion(&iface->out_complete); + + ret = dchid_send(iface, iface->out_flags, data, size); + if (ret < 0) + goto done; + + if (!wait_for_completion_timeout(&iface->out_complete, msecs_to_jiffies(COMMAND_TIMEOUT_MS))) { + dev_err(iface->dchid->dev, "output report 0x%x to iface %d (%s) timed out\n", + report_id, iface->index, iface->name); + ret = -ETIMEDOUT; + goto done; + } + + ret = iface->resp_size; + if (iface->retcode) { + dev_err(iface->dchid->dev, + "output report 0x%x to iface %d (%s) failed with err 0x%x\n", + report_id, iface->index, iface->name, iface->retcode); + ret = -EIO; + } + +done: + iface->tx_seq++; + iface->out_report = -1; + iface->out_flags = 0; + iface->resp_buf = NULL; + iface->resp_size = 0; + mutex_unlock(&iface->out_mutex); + return ret; +} + +static int dchid_comm_cmd(struct dockchannel_hid *dchid, void *cmd, size_t size) +{ + return dchid_cmd(dchid->comm, HID_FEATURE_REPORT, REQ_SET_REPORT, cmd, size, NULL, 0); +} + +static int dchid_enable_interface(struct dchid_iface *iface) +{ + u8 msg[] = { CMD_ENABLE_INTERFACE, iface->index }; + + return dchid_comm_cmd(iface->dchid, msg, sizeof(msg)); +} + +static int dchid_reset_interface(struct dchid_iface *iface, int state) +{ + u8 msg[] = { CMD_RESET_INTERFACE, 1, iface->index, state }; + + return dchid_comm_cmd(iface->dchid, msg, sizeof(msg)); +} + +static int dchid_send_firmware(struct dchid_iface *iface, void *firmware, size_t size) +{ + struct { + u8 cmd; + u8 unk1; + u8 unk2; + u8 iface; + u64 addr; + u32 size; + } __packed msg = { + .cmd = CMD_SEND_FIRMWARE, + .unk1 = 2, + .unk2 = 0, + .iface = iface->index, + .size = size, + }; + dma_addr_t addr; + void *buf = dmam_alloc_coherent(iface->dchid->dev, size, &addr, GFP_KERNEL); + + if (IS_ERR_OR_NULL(buf)) + return buf ? PTR_ERR(buf) : -ENOMEM; + + msg.addr = addr; + memcpy(buf, firmware, size); + wmb(); + + return dchid_comm_cmd(iface->dchid, &msg, sizeof(msg)); +} + +static int dchid_get_firmware(struct dchid_iface *iface, void **firmware, size_t *size) +{ + int ret; + const char *fw_name; + const struct firmware *fw; + struct fw_header *hdr; + u8 *fw_data; + + ret = of_property_read_string(iface->of_node, "firmware-name", &fw_name); + if (ret) { + /* Firmware is only for some devices */ + *firmware = NULL; + *size = 0; + return 0; + } + + ret = request_firmware(&fw, fw_name, iface->dchid->dev); + if (ret) + return ret; + + hdr = (struct fw_header *)fw->data; + + if (hdr->magic != FW_MAGIC || hdr->version != FW_VER || + hdr->hdr_length < sizeof(*hdr) || hdr->hdr_length > fw->size || + (hdr->hdr_length + (size_t)hdr->data_length) > fw->size || + hdr->iface_offset >= hdr->data_length) { + dev_warn(iface->dchid->dev, "%s: invalid firmware header\n", + fw_name); + ret = -EINVAL; + goto done; + } + + fw_data = devm_kmemdup(iface->dchid->dev, fw->data + hdr->hdr_length, + hdr->data_length, GFP_KERNEL); + if (!fw_data) { + ret = -ENOMEM; + goto done; + } + + if (hdr->iface_offset) + fw_data[hdr->iface_offset] = iface->index; + + *firmware = fw_data; + *size = hdr->data_length; + +done: + release_firmware(fw); + return ret; +} + +static int dchid_request_gpio(struct dchid_iface *iface) +{ + char prop_name[MAX_GPIO_NAME + 16]; + + if (iface->gpio) + return 0; + + dev_info(iface->dchid->dev, "Requesting GPIO %s#%d: %s\n", + iface->name, iface->gpio_id, iface->gpio_name); + + snprintf(prop_name, sizeof(prop_name), "apple,%s", iface->gpio_name); + + iface->gpio = devm_gpiod_get_index(iface->dchid->dev, prop_name, 0, GPIOD_OUT_LOW); + + if (IS_ERR_OR_NULL(iface->gpio)) { + dev_err(iface->dchid->dev, "Failed to request GPIO %s-gpios\n", prop_name); + iface->gpio = NULL; + return -1; + } + + return 0; +} + +static int dchid_start_interface(struct dchid_iface *iface) +{ + void *fw; + size_t size; + int ret; + + if (iface->starting) { + dev_warn(iface->dchid->dev, "Interface %s is already starting", iface->name); + return -EINPROGRESS; + } + + dev_info(iface->dchid->dev, "Starting interface %s\n", iface->name); + + iface->starting = true; + + /* Look to see if we need firmware */ + ret = dchid_get_firmware(iface, &fw, &size); + if (ret < 0) + goto err; + + /* If we need a GPIO, make sure we have it. */ + if (iface->gpio_id) { + ret = dchid_request_gpio(iface); + if (ret < 0) + goto err; + } + + /* Only multi-touch has firmware */ + if (fw && size) { + + /* Send firmware to the device */ + dev_info(iface->dchid->dev, "Sending firmware for %s\n", iface->name); + ret = dchid_send_firmware(iface, fw, size); + if (ret < 0) { + dev_err(iface->dchid->dev, "Failed to send %s firmwareS", iface->name); + goto err; + } + + /* After loading firmware, multi-touch needs a reset */ + dev_info(iface->dchid->dev, "Resetting %s\n", iface->name); + dchid_reset_interface(iface, 0); + dchid_reset_interface(iface, 2); + } + + return 0; + +err: + iface->starting = false; + return ret; +} + +static int dchid_start(struct hid_device *hdev) +{ + struct dchid_iface *iface = hdev->driver_data; + + if (iface->keyboard_layout_id) { + int ret = device_create_file(&hdev->dev, &dev_attr_apple_layout_id); + if (ret) { + dev_warn(iface->dchid->dev, "Failed to create apple_layout_id: %d", ret); + iface->keyboard_layout_id = 0; + } + } + + return 0; +}; + +static void dchid_stop(struct hid_device *hdev) +{ + struct dchid_iface *iface = hdev->driver_data; + + if (iface->keyboard_layout_id) + device_remove_file(&hdev->dev, &dev_attr_apple_layout_id); +} + +static int dchid_open(struct hid_device *hdev) +{ + struct dchid_iface *iface = hdev->driver_data; + int ret; + + if (!completion_done(&iface->ready)) { + ret = dchid_start_interface(iface); + if (ret < 0) + return ret; + + if (!wait_for_completion_timeout(&iface->ready, msecs_to_jiffies(START_TIMEOUT_MS))) { + dev_err(iface->dchid->dev, "iface %s start timed out\n", iface->name); + return -ETIMEDOUT; + } + } + + iface->open = true; + return 0; +} + +static void dchid_close(struct hid_device *hdev) +{ + struct dchid_iface *iface = hdev->driver_data; + + iface->open = false; +} + +static int dchid_parse(struct hid_device *hdev) +{ + struct dchid_iface *iface = hdev->driver_data; + + return hid_parse_report(hdev, iface->hid_desc, iface->hid_desc_len); +} + +/* Note: buf excludes report number! For ease of fetching strings/etc. */ +static int dchid_get_report_cmd(struct dchid_iface *iface, u8 reportnum, void *buf, size_t len) +{ + int ret = dchid_cmd(iface, HID_FEATURE_REPORT, REQ_GET_REPORT, &reportnum, 1, buf, len); + + return ret <= 0 ? ret : ret - 1; +} + +/* Note: buf includes report number! */ +static int dchid_set_report(struct dchid_iface *iface, void *buf, size_t len) +{ + return dchid_cmd(iface, HID_OUTPUT_REPORT, REQ_SET_REPORT, buf, len, NULL, 0); +} + +static int dchid_raw_request(struct hid_device *hdev, + unsigned char reportnum, __u8 *buf, size_t len, + unsigned char rtype, int reqtype) +{ + struct dchid_iface *iface = hdev->driver_data; + + switch (reqtype) { + case HID_REQ_GET_REPORT: + buf[0] = reportnum; + return dchid_cmd(iface, rtype, REQ_GET_REPORT, &reportnum, 1, buf + 1, len - 1); + case HID_REQ_SET_REPORT: + return dchid_set_report(iface, buf, len); + default: + return -EIO; + } + + return 0; +} + +static struct hid_ll_driver dchid_ll = { + .start = &dchid_start, + .stop = &dchid_stop, + .open = &dchid_open, + .close = &dchid_close, + .parse = &dchid_parse, + .raw_request = &dchid_raw_request, +}; + +static void dchid_create_interface_work(struct work_struct *ws) +{ + struct dchid_iface *iface = container_of(ws, struct dchid_iface, create_work); + struct dockchannel_hid *dchid = iface->dchid; + struct hid_device *hid; + int ret; + + if (iface->hid) { + dev_warn(dchid->dev, "Interface %s already created!\n", + iface->name); + return; + } + + dev_info(dchid->dev, "New interface %s\n", iface->name); + + /* Start the interface. This is not the entire init process, as firmware is loaded later on device open. */ + ret = dchid_enable_interface(iface); + if (ret < 0) { + dev_warn(dchid->dev, "Failed to enable %s: %d\n", iface->name, ret); + return; + } + + iface->deferred = false; + + hid = hid_allocate_device(); + if (IS_ERR(hid)) + return; + + snprintf(hid->name, sizeof(hid->name), "Apple MTP %s", iface->name); + snprintf(hid->phys, sizeof(hid->phys), "%s.%d (%s)", + dev_name(dchid->dev), iface->index, iface->name); + strscpy(hid->uniq, dchid->serial, sizeof(hid->uniq)); + + hid->ll_driver = &dchid_ll; + hid->bus = BUS_HOST; + hid->vendor = dchid->device_id.vendor_id; + hid->product = dchid->device_id.product_id; + hid->version = dchid->device_id.version_number; + hid->type = HID_TYPE_OTHER; + if (!strcmp(iface->name, "multi-touch")) { + hid->type = HID_TYPE_SPI_MOUSE; + } else if (!strcmp(iface->name, "keyboard")) { + u32 country_code = 0; + + hid->type = HID_TYPE_SPI_KEYBOARD; + + /* + * We have to get the country code from the device tree, since the + * device provides no reliable way to get this info. + */ + if (!of_property_read_u32(iface->of_node, "hid-country-code", &country_code)) + hid->country = country_code; + + of_property_read_u32(iface->of_node, "apple,keyboard-layout-id", + &iface->keyboard_layout_id); + } + + hid->dev.parent = iface->dchid->dev; + hid->driver_data = iface; + + iface->hid = hid; + + ret = hid_add_device(hid); + if (ret < 0) { + iface->hid = NULL; + hid_destroy_device(hid); + dev_warn(iface->dchid->dev, "Failed to register hid device %s", iface->name); + } +} + +static int dchid_create_interface(struct dchid_iface *iface) +{ + if (iface->creating) + return -EBUSY; + + iface->creating = true; + INIT_WORK(&iface->create_work, dchid_create_interface_work); + return queue_work(iface->dchid->new_iface_wq, &iface->create_work); +} + +static void dchid_handle_descriptor(struct dchid_iface *iface, void *hid_desc, size_t desc_len) +{ + if (iface->hid) { + dev_warn(iface->dchid->dev, "Tried to initialize already started interface %s!\n", + iface->name); + return; + } + + iface->hid_desc = devm_kmemdup(iface->dchid->dev, hid_desc, desc_len, GFP_KERNEL); + if (!iface->hid_desc) + return; + + iface->hid_desc_len = desc_len; +} + +static void dchid_handle_ready(struct dockchannel_hid *dchid, void *data, size_t length) +{ + struct dchid_iface *iface; + u8 *pkt = data; + u8 index; + int i, ret; + + if (length < 2) { + dev_err(dchid->dev, "Bad length for ready message: %zu\n", length); + return; + } + + index = pkt[1]; + + if (index >= MAX_INTERFACES) { + dev_err(dchid->dev, "Got ready notification for bad iface %d\n", index); + return; + } + + iface = dchid->ifaces[index]; + if (!iface) { + dev_err(dchid->dev, "Got ready notification for unknown iface %d\n", index); + return; + } + + dev_info(dchid->dev, "Interface %s is now ready\n", iface->name); + complete_all(&iface->ready); + + /* When STM is ready, grab global device info */ + if (!strcmp(iface->name, "stm")) { + ret = dchid_get_report_cmd(iface, STM_REPORT_ID, &dchid->device_id, + sizeof(dchid->device_id)); + if (ret < sizeof(dchid->device_id)) { + dev_warn(iface->dchid->dev, "Failed to get device ID from STM!\n"); + /* Fake it and keep going. Things might still work... */ + memset(&dchid->device_id, 0, sizeof(dchid->device_id)); + dchid->device_id.vendor_id = HOST_VENDOR_ID_APPLE; + } + ret = dchid_get_report_cmd(iface, STM_REPORT_SERIAL, dchid->serial, + sizeof(dchid->serial) - 1); + if (ret < 0) { + dev_warn(iface->dchid->dev, "Failed to get serial from STM!\n"); + dchid->serial[0] = 0; + } + + dchid->id_ready = true; + for (i = 0; i < MAX_INTERFACES; i++) { + if (!dchid->ifaces[i] || !dchid->ifaces[i]->deferred) + continue; + dchid_create_interface(dchid->ifaces[i]); + } + } +} + +static void dchid_handle_init(struct dockchannel_hid *dchid, void *data, size_t length) +{ + struct dchid_init_hdr *hdr = data; + struct dchid_iface *iface; + struct dchid_init_block_hdr *blk; + + if (length < sizeof(*hdr)) + return; + + iface = dchid_get_interface(dchid, hdr->iface, hdr->name); + if (!iface) + return; + + data += sizeof(*hdr); + length -= sizeof(*hdr); + + while (length >= sizeof(*blk)) { + blk = data; + data += sizeof(*blk); + length -= sizeof(*blk); + + if (blk->length > length) + break; + + switch (blk->type) { + case INIT_HID_DESCRIPTOR: + dchid_handle_descriptor(iface, data, blk->length); + break; + + case INIT_GPIO_REQUEST: { + struct dchid_gpio_request *req = data; + + if (sizeof(*req) > length) + break; + + if (iface->gpio_id) { + dev_err(dchid->dev, + "Cannot request more than one GPIO per interface!\n"); + break; + } + + strscpy(iface->gpio_name, req->name, MAX_GPIO_NAME); + iface->gpio_id = req->id; + break; + } + + case INIT_TERMINATOR: + break; + + case INIT_PRODUCT_NAME: { + char *product = data; + + if (product[blk->length - 1] != 0) { + dev_warn(dchid->dev, "Unterminated product name for %s\n", + iface->name); + } else { + dev_info(dchid->dev, "Product name for %s: %s\n", + iface->name, product); + } + break; + } + + default: + dev_warn(dchid->dev, "Unknown init packet %d for %s\n", + blk->type, iface->name); + break; + } + + data += blk->length; + length -= blk->length; + + if (blk->type == INIT_TERMINATOR) + break; + } + + if (hdr->more_packets) + return; + + /* We need to enable STM first, since it'll give us the device IDs */ + if (iface->dchid->id_ready || !strcmp(iface->name, "stm")) { + dchid_create_interface(iface); + } else { + iface->deferred = true; + } +} + +static void dchid_handle_gpio(struct dockchannel_hid *dchid, void *data, size_t length) +{ + struct dchid_gpio_cmd *cmd = data; + struct dchid_iface *iface; + u32 retcode = 0xe000f00d; /* Give it a random Apple-style error code */ + struct dchid_gpio_ack *ack; + + if (length < sizeof(*cmd)) + return; + + if (cmd->iface >= MAX_INTERFACES || !(iface = dchid->ifaces[cmd->iface])) { + dev_err(dchid->dev, "Got GPIO command for bad inteface %d\n", cmd->iface); + goto err; + } + + if (dchid_request_gpio(iface) < 0) + goto err; + + if (!iface->gpio || cmd->gpio != iface->gpio_id) { + dev_err(dchid->dev, "Got GPIO command for bad GPIO %s#%d\n", + iface->name, cmd->gpio); + goto err; + } + + dev_info(dchid->dev, "GPIO command: %s#%d: %d\n", iface->name, cmd->gpio, cmd->cmd); + + switch (cmd->cmd) { + case 3: + /* Pulse. */ + gpiod_set_value_cansleep(iface->gpio, 1); + msleep(10); /* Random guess... */ + gpiod_set_value_cansleep(iface->gpio, 0); + retcode = 0; + break; + default: + dev_err(dchid->dev, "Unknown GPIO command %d\n", cmd->cmd ); + break; + } + +err: + /* Ack it */ + ack = kzalloc(sizeof(*ack) + length, GFP_KERNEL); + if (!ack) + return; + + ack->type = CMD_ACK_GPIO_CMD; + ack->retcode = retcode; + memcpy(ack->cmd, data, length); + + if (dchid_comm_cmd(dchid, ack, sizeof(*ack) + length) < 0) + dev_err(dchid->dev, "Failed to ACK GPIO command\n"); + + kfree(ack); +} + +static void dchid_handle_event(struct dockchannel_hid *dchid, void *data, size_t length) +{ + u8 *p = data; + switch (*p) { + case EVENT_INIT: + dchid_handle_init(dchid, data, length); + break; + case EVENT_READY: + dchid_handle_ready(dchid, data, length); + break; + case EVENT_GPIO_CMD: + dchid_handle_gpio(dchid, data, length); + break; + } +} + +static void dchid_handle_report(struct dchid_iface *iface, void *data, size_t length) +{ + struct dockchannel_hid *dchid = iface->dchid; + + if (!iface->hid) { + dev_warn(dchid->dev, "Report received but %s is not initialized!\n", iface->name); + return; + } + + if (!iface->open) + return; + + hid_input_report(iface->hid, HID_INPUT_REPORT, data, length, 1); +} + +static void dchid_packet_work(struct work_struct *ws) +{ + struct dchid_work *work = container_of(ws, struct dchid_work, work); + struct dchid_subhdr *shdr = (void *)work->data; + struct dockchannel_hid *dchid = work->iface->dchid; + int type = FIELD_GET(FLAGS_GROUP, shdr->flags); + u8 *payload = work->data + sizeof(*shdr); + + if (shdr->length + sizeof(*shdr) > work->hdr.length) { + dev_err(dchid->dev, "Bad sub header length (%d > %zu)\n", + shdr->length, work->hdr.length - sizeof(*shdr)); + return; + } + + switch (type) { + case HID_INPUT_REPORT: + if (work->hdr.iface == IFACE_COMM) + dchid_handle_event(dchid, payload, shdr->length); + else + dchid_handle_report(work->iface, payload, shdr->length); + break; + default: + dev_err(dchid->dev, "Received unknown packet type %d\n", type); + break; + } + + kfree(work); +} + +static void dchid_handle_ack(struct dchid_iface *iface, struct dchid_hdr *hdr, void *data) +{ + struct dchid_subhdr *shdr = (void *)data; + u8 *payload = data + sizeof(*shdr); + + if (shdr->length + sizeof(*shdr) > hdr->length) { + dev_err(iface->dchid->dev, "Bad sub header length (%d > %ld)\n", + shdr->length, hdr->length - sizeof(*shdr)); + return; + } + if (shdr->flags != iface->out_flags) { + dev_err(iface->dchid->dev, + "Received unexpected flags 0x%x on ACK channel (expFected 0x%x)\n", + shdr->flags, iface->out_flags); + return; + } + + if (shdr->length < 1) { + dev_err(iface->dchid->dev, "Received length 0 output report ack\n"); + return; + } + if (iface->tx_seq != hdr->seq) { + dev_err(iface->dchid->dev, "Received ACK with bad seq (expected %d, got %d)\n", + iface->tx_seq, hdr->seq); + return; + } + if (iface->out_report != payload[0]) { + dev_err(iface->dchid->dev, "Received ACK with bad report (expected %d, got %d\n", + iface->out_report, payload[0]); + return; + } + + if (iface->resp_buf && iface->resp_size) + memcpy(iface->resp_buf, payload + 1, min((size_t)shdr->length - 1, iface->resp_size)); + + iface->resp_size = shdr->length; + iface->out_report = -1; + iface->retcode = shdr->retcode; + complete(&iface->out_complete); +} + +static void dchid_handle_packet(void *cookie, size_t avail) +{ + struct dockchannel_hid *dchid = cookie; + struct dchid_hdr hdr; + struct dchid_work *work; + struct dchid_iface *iface; + u32 checksum; + + if (dockchannel_recv(dchid->dc, &hdr, sizeof(hdr)) != sizeof(hdr)) { + dev_err(dchid->dev, "Read failed (header)\n"); + return; + } + + if (hdr.hdr_len != sizeof(hdr)) { + dev_err(dchid->dev, "Bad header length %d\n", hdr.hdr_len); + goto done; + } + + if (dockchannel_recv(dchid->dc, dchid->pkt_buf, hdr.length + 4) != (hdr.length + 4)) { + dev_err(dchid->dev, "Read failed (body)\n"); + goto done; + } + + checksum = dchid_checksum(&hdr, sizeof(hdr)); + checksum += dchid_checksum(dchid->pkt_buf, hdr.length + 4); + + if (checksum != 0xffffffff) { + dev_err(dchid->dev, "Checksum mismatch (iface %d): 0x%08x != 0xffffffff\n", + hdr.iface, checksum); + goto done; + } + + + if (hdr.iface >= MAX_INTERFACES) { + dev_err(dchid->dev, "Bad iface %d\n", hdr.iface); + } + + iface = dchid->ifaces[hdr.iface]; + + if (!iface) { + dev_err(dchid->dev, "Received packet for uninitialized iface %d\n", hdr.iface); + goto done; + } + + switch (hdr.channel) { + case DCHID_CHANNEL_CMD: + dchid_handle_ack(iface, &hdr, dchid->pkt_buf); + goto done; + case DCHID_CHANNEL_REPORT: + break; + default: + dev_warn(dchid->dev, "Unknown channel 0x%x, treating as report...\n", + hdr.channel); + break; + } + + work = kzalloc(sizeof(*work) + hdr.length, GFP_KERNEL); + if (!work) + return; + + work->hdr = hdr; + work->iface = iface; + memcpy(work->data, dchid->pkt_buf, hdr.length); + INIT_WORK(&work->work, dchid_packet_work); + + queue_work(iface->wq, &work->work); + +done: + dockchannel_await(dchid->dc, dchid_handle_packet, dchid, sizeof(struct dchid_hdr)); +} + +static int dockchannel_hid_probe(struct platform_device *pdev) +{ + struct device *dev = &pdev->dev; + struct dockchannel_hid *dchid; + struct device_node *child, *helper; + struct platform_device *helper_pdev; + struct property *prop; + int ret; + + ret = dma_set_mask_and_coherent(&pdev->dev, DMA_BIT_MASK(64)); + if (ret) + return ret; + + dchid = devm_kzalloc(dev, sizeof(*dchid), GFP_KERNEL); + if (!dchid) { + return -ENOMEM; + } + + dchid->dev = dev; + + /* + * First make sure all the GPIOs are available, in cased we need to defer. + * This is necessary because MTP will request them by name later, and by then + * it's too late to defer the probe. + */ + + for_each_child_of_node(dev->of_node, child) { + for_each_property_of_node(child, prop) { + size_t len = strlen(prop->name); + struct gpio_desc *gpio; + + if (len < 12 || strncmp("apple,", prop->name, 6) || + strcmp("-gpios", prop->name + len - 6)) + continue; + + gpio = fwnode_gpiod_get_index(&child->fwnode, prop->name, 0, GPIOD_ASIS, + prop->name); + if (IS_ERR_OR_NULL(gpio)) { + if (PTR_ERR(gpio) == -EPROBE_DEFER) { + of_node_put(child); + return -EPROBE_DEFER; + } + } else { + gpiod_put(gpio); + } + } + } + + /* + * Make sure we also have the MTP coprocessor available, and + * defer probe if the helper hasn't probed yet. + */ + helper = of_parse_phandle(dev->of_node, "apple,helper-cpu", 0); + if (!helper) { + dev_err(dev, "Missing apple,helper-cpu property"); + return -EINVAL; + } + + helper_pdev = of_find_device_by_node(helper); + of_node_put(helper); + if (!helper_pdev) { + dev_err(dev, "Failed to find helper device"); + return -EINVAL; + } + + dchid->helper_link = device_link_add(dev, &helper_pdev->dev, + DL_FLAG_AUTOREMOVE_CONSUMER); + put_device(&helper_pdev->dev); + if (!dchid->helper_link) { + dev_err(dev, "Failed to link to helper device"); + return -EINVAL; + } + + if (dchid->helper_link->supplier->links.status != DL_DEV_DRIVER_BOUND) + return -EPROBE_DEFER; + + /* Now it is safe to begin initializing */ + dchid->dc = dockchannel_init(pdev); + if (IS_ERR_OR_NULL(dchid->dc)) { + return PTR_ERR(dchid->dc); + } + dchid->new_iface_wq = alloc_workqueue("dchid-new", WQ_MEM_RECLAIM, 0); + if (!dchid->new_iface_wq) + return -ENOMEM; + + dchid->comm = dchid_get_interface(dchid, IFACE_COMM, "comm"); + if (!dchid->comm) { + dev_err(dchid->dev, "Failed to initialize comm interface"); + return -EIO; + } + + dev_info(dchid->dev, "Initialized, awaiting packets\n"); + dockchannel_await(dchid->dc, dchid_handle_packet, dchid, sizeof(struct dchid_hdr)); + + return 0; +} + +static void dockchannel_hid_remove(struct platform_device *pdev) +{ + BUG_ON(1); +} + +static const struct of_device_id dockchannel_hid_of_match[] = { + { .compatible = "apple,dockchannel-hid" }, + {}, +}; +MODULE_DEVICE_TABLE(of, dockchannel_hid_of_match); +MODULE_FIRMWARE("apple/tpmtfw-*.bin"); + +static struct platform_driver dockchannel_hid_driver = { + .driver = { + .name = "dockchannel-hid", + .of_match_table = dockchannel_hid_of_match, + }, + .probe = dockchannel_hid_probe, + .remove = dockchannel_hid_remove, +}; +module_platform_driver(dockchannel_hid_driver); + +MODULE_DESCRIPTION("Apple DockChannel HID transport driver"); +MODULE_AUTHOR("Hector Martin "); +MODULE_LICENSE("Dual MIT/GPL"); diff --git a/drivers/hid/hid-apple.c b/drivers/hid/hid-apple.c index d900dd05c335..1217b15c3e2b 100644 --- a/drivers/hid/hid-apple.c +++ b/drivers/hid/hid-apple.c @@ -276,6 +276,50 @@ static const struct apple_key_translation apple_fn_keys[] = { { } }; +static const struct apple_key_translation apple_fn_keys_spi[] = { + { KEY_BACKSPACE, KEY_DELETE }, + { KEY_ENTER, KEY_INSERT }, + { KEY_F1, KEY_BRIGHTNESSDOWN, APPLE_FLAG_FKEY }, + { KEY_F2, KEY_BRIGHTNESSUP, APPLE_FLAG_FKEY }, + { KEY_F3, KEY_SCALE, APPLE_FLAG_FKEY }, + { KEY_F4, KEY_SEARCH, APPLE_FLAG_FKEY }, + { KEY_F5, KEY_RECORD, APPLE_FLAG_FKEY }, + { KEY_F6, KEY_SLEEP, APPLE_FLAG_FKEY }, + { KEY_F7, KEY_PREVIOUSSONG, APPLE_FLAG_FKEY }, + { KEY_F8, KEY_PLAYPAUSE, APPLE_FLAG_FKEY }, + { KEY_F9, KEY_NEXTSONG, APPLE_FLAG_FKEY }, + { KEY_F10, KEY_MUTE, APPLE_FLAG_FKEY }, + { KEY_F11, KEY_VOLUMEDOWN, APPLE_FLAG_FKEY }, + { KEY_F12, KEY_VOLUMEUP, APPLE_FLAG_FKEY }, + { KEY_UP, KEY_PAGEUP }, + { KEY_DOWN, KEY_PAGEDOWN }, + { KEY_LEFT, KEY_HOME }, + { KEY_RIGHT, KEY_END }, + { } +}; + +static const struct apple_key_translation apple_fn_keys_mbp13[] = { + { KEY_BACKSPACE, KEY_DELETE }, + { KEY_ENTER, KEY_INSERT }, + { KEY_UP, KEY_PAGEUP }, + { KEY_DOWN, KEY_PAGEDOWN }, + { KEY_LEFT, KEY_HOME }, + { KEY_RIGHT, KEY_END }, + { KEY_1, KEY_F1 }, + { KEY_2, KEY_F2 }, + { KEY_3, KEY_F3 }, + { KEY_4, KEY_F4 }, + { KEY_5, KEY_F5 }, + { KEY_6, KEY_F6 }, + { KEY_7, KEY_F7 }, + { KEY_8, KEY_F8 }, + { KEY_9, KEY_F9 }, + { KEY_0, KEY_F10 }, + { KEY_MINUS, KEY_F11 }, + { KEY_EQUAL, KEY_F12 }, + { } +}; + static const struct apple_key_translation powerbook_fn_keys[] = { { KEY_BACKSPACE, KEY_DELETE }, { KEY_F1, KEY_BRIGHTNESSDOWN, APPLE_FLAG_FKEY }, @@ -486,6 +530,7 @@ static int hidinput_apple_event(struct hid_device *hid, struct input_dev *input, table = apple2021_fn_keys; else if (hid->product == USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132 || hid->product == USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680 || + hid->product == USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT || hid->product == USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213) table = macbookpro_no_esc_fn_keys; else if (hid->product == USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K || @@ -498,6 +543,16 @@ static int hidinput_apple_event(struct hid_device *hid, struct input_dev *input, else if (hid->product >= USB_DEVICE_ID_APPLE_WELLSPRING4_ANSI && hid->product <= USB_DEVICE_ID_APPLE_WELLSPRING4A_JIS) table = macbookair_fn_keys; + else if (hid->bus == BUS_HOST || hid->bus == BUS_SPI) + switch (hid->product) { + case SPI_DEVICE_ID_APPLE_MACBOOK_PRO13_2020: + case HOST_DEVICE_ID_APPLE_MACBOOK_PRO13_2022: + table = apple_fn_keys_mbp13; + break; + default: + table = apple_fn_keys_spi; + break; + } else if (hid->product < 0x21d || hid->product >= 0x300) table = powerbook_fn_keys; else @@ -677,6 +732,8 @@ static void apple_setup_input(struct input_dev *input) /* Enable all needed keys */ apple_setup_key_translation(input, apple_fn_keys); + apple_setup_key_translation(input, apple_fn_keys_spi); + apple_setup_key_translation(input, apple_fn_keys_mbp13); apple_setup_key_translation(input, powerbook_fn_keys); apple_setup_key_translation(input, powerbook_numlock_keys); apple_setup_key_translation(input, apple_iso_keyboard); @@ -910,6 +967,13 @@ static int apple_probe(struct hid_device *hdev, struct apple_sc *asc; int ret; + if ((id->bus == BUS_SPI || id->bus == BUS_HOST) && id->vendor == SPI_VENDOR_ID_APPLE && + hdev->type != HID_TYPE_SPI_KEYBOARD) + return -ENODEV; + + if (quirks & APPLE_IGNORE_MOUSE && hdev->type == HID_TYPE_USBMOUSE) + return -ENODEV; + asc = devm_kzalloc(&hdev->dev, sizeof(*asc), GFP_KERNEL); if (asc == NULL) { hid_err(hdev, "can't alloc apple descriptor\n"); @@ -1127,21 +1191,28 @@ static const struct hid_device_id apple_devices[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRING9_JIS), .driver_data = APPLE_HAS_FN | APPLE_RDESC_JIS }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K), - .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK | + APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132), - .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK | + APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680), - .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK | + APPLE_IGNORE_MOUSE }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT), + .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK | + APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213), - .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_BACKLIGHT_CTL | APPLE_ISO_TILDE_QUIRK | + APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K), - .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223), - .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K), - .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_IGNORE_MOUSE }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F), - .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_IGNORE_MOUSE }, { HID_BLUETOOTH_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_ALU_WIRELESS_2009_ANSI), .driver_data = APPLE_NUMLOCK_EMULATION | APPLE_HAS_FN }, { HID_BLUETOOTH_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_ALU_WIRELESS_2009_ISO), @@ -1169,6 +1240,10 @@ static const struct hid_device_id apple_devices[] = { .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK | APPLE_RDESC_BATTERY }, { HID_BLUETOOTH_DEVICE(BT_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_NUMPAD_2021), .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + { HID_SPI_DEVICE(SPI_VENDOR_ID_APPLE, HID_ANY_ID), + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, + { HID_DEVICE(BUS_HOST, HID_GROUP_ANY, HOST_VENDOR_ID_APPLE, HID_ANY_ID), + .driver_data = APPLE_HAS_FN | APPLE_ISO_TILDE_QUIRK }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT), .driver_data = APPLE_MAGIC_BACKLIGHT }, diff --git a/drivers/hid/hid-appletb-bl.c b/drivers/hid/hid-appletb-bl.c new file mode 100644 index 000000000000..bad2aead8780 --- /dev/null +++ b/drivers/hid/hid-appletb-bl.c @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar Backlight Driver + * + * Copyright (c) 2017-2018 Ronald Tschalär + * Copyright (c) 2022-2023 Kerem Karabay + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +#include "hid-ids.h" + +#define APPLETB_BL_ON 1 +#define APPLETB_BL_DIM 3 +#define APPLETB_BL_OFF 4 + +#define HID_UP_APPLEVENDOR_TB_BL 0xff120000 + +#define HID_VD_APPLE_TB_BRIGHTNESS 0xff120001 +#define HID_USAGE_AUX1 0xff120020 +#define HID_USAGE_BRIGHTNESS 0xff120021 + +static int appletb_bl_def_brightness = 2; +module_param_named(brightness, appletb_bl_def_brightness, int, 0444); +MODULE_PARM_DESC(brightness, "Default brightness:\n" + " 0 - Touchbar is off\n" + " 1 - Dim brightness\n" + " [2] - Full brightness"); + +struct appletb_bl { + struct hid_field *aux1_field, *brightness_field; + struct backlight_device *bdev; + + bool full_on; +}; + +static const u8 appletb_bl_brightness_map[] = { + APPLETB_BL_OFF, + APPLETB_BL_DIM, + APPLETB_BL_ON, +}; + +static int appletb_bl_set_brightness(struct appletb_bl *bl, u8 brightness) +{ + struct hid_report *report = bl->brightness_field->report; + struct hid_device *hdev = report->device; + int ret; + + ret = hid_set_field(bl->aux1_field, 0, 1); + if (ret) { + hid_err(hdev, "Failed to set auxiliary field (%pe)\n", ERR_PTR(ret)); + return ret; + } + + ret = hid_set_field(bl->brightness_field, 0, brightness); + if (ret) { + hid_err(hdev, "Failed to set brightness field (%pe)\n", ERR_PTR(ret)); + return ret; + } + + if (!bl->full_on) { + ret = hid_hw_power(hdev, PM_HINT_FULLON); + if (ret < 0) { + hid_err(hdev, "Device didn't power on (%pe)\n", ERR_PTR(ret)); + return ret; + } + + bl->full_on = true; + } + + hid_hw_request(hdev, report, HID_REQ_SET_REPORT); + + if (brightness == APPLETB_BL_OFF) { + hid_hw_power(hdev, PM_HINT_NORMAL); + bl->full_on = false; + } + + return 0; +} + +static int appletb_bl_update_status(struct backlight_device *bdev) +{ + struct appletb_bl *bl = bl_get_data(bdev); + u8 brightness; + + if (backlight_is_blank(bdev)) + brightness = APPLETB_BL_OFF; + else + brightness = appletb_bl_brightness_map[backlight_get_brightness(bdev)]; + + return appletb_bl_set_brightness(bl, brightness); +} + +static const struct backlight_ops appletb_bl_backlight_ops = { + .options = BL_CORE_SUSPENDRESUME, + .update_status = appletb_bl_update_status, +}; + +static int appletb_bl_probe(struct hid_device *hdev, const struct hid_device_id *id) +{ + struct hid_field *aux1_field, *brightness_field; + struct backlight_properties bl_props = { 0 }; + struct device *dev = &hdev->dev; + struct appletb_bl *bl; + int ret; + + ret = hid_parse(hdev); + if (ret) + return dev_err_probe(dev, ret, "HID parse failed\n"); + + aux1_field = hid_find_field(hdev, HID_FEATURE_REPORT, + HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_AUX1); + + brightness_field = hid_find_field(hdev, HID_FEATURE_REPORT, + HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_BRIGHTNESS); + + if (!aux1_field || !brightness_field) + return -ENODEV; + + if (aux1_field->report != brightness_field->report) + return dev_err_probe(dev, -ENODEV, "Encountered unexpected report structure\n"); + + bl = devm_kzalloc(dev, sizeof(*bl), GFP_KERNEL); + if (!bl) + return -ENOMEM; + + ret = hid_hw_start(hdev, HID_CONNECT_DRIVER); + if (ret) + return dev_err_probe(dev, ret, "HID hardware start failed\n"); + + ret = hid_hw_open(hdev); + if (ret) { + dev_err_probe(dev, ret, "HID hardware open failed\n"); + goto stop_hw; + } + + bl->aux1_field = aux1_field; + bl->brightness_field = brightness_field; + + ret = appletb_bl_set_brightness(bl, + appletb_bl_brightness_map[(appletb_bl_def_brightness > 2) ? 2 : appletb_bl_def_brightness]); + + if (ret) { + dev_err_probe(dev, ret, "Failed to set default touch bar brightness to %d\n", + appletb_bl_def_brightness); + goto close_hw; + } + + bl_props.type = BACKLIGHT_RAW; + bl_props.max_brightness = ARRAY_SIZE(appletb_bl_brightness_map) - 1; + + bl->bdev = devm_backlight_device_register(dev, "appletb_backlight", dev, bl, + &appletb_bl_backlight_ops, &bl_props); + if (IS_ERR(bl->bdev)) { + ret = PTR_ERR(bl->bdev); + dev_err_probe(dev, ret, "Failed to register backlight device\n"); + goto close_hw; + } + + hid_set_drvdata(hdev, bl); + + return 0; + +close_hw: + hid_hw_close(hdev); +stop_hw: + hid_hw_stop(hdev); + + return ret; +} + +static void appletb_bl_remove(struct hid_device *hdev) +{ + struct appletb_bl *bl = hid_get_drvdata(hdev); + + appletb_bl_set_brightness(bl, APPLETB_BL_OFF); + + hid_hw_close(hdev); + hid_hw_stop(hdev); +} + +static const struct hid_device_id appletb_bl_hid_ids[] = { + /* MacBook Pro's 2018, 2019, with T2 chip: iBridge DFR Brightness */ + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, + { } +}; +MODULE_DEVICE_TABLE(hid, appletb_bl_hid_ids); + +static struct hid_driver appletb_bl_hid_driver = { + .name = "hid-appletb-bl", + .id_table = appletb_bl_hid_ids, + .probe = appletb_bl_probe, + .remove = appletb_bl_remove, +}; +module_hid_driver(appletb_bl_hid_driver); + +MODULE_AUTHOR("Ronald Tschalär"); +MODULE_AUTHOR("Kerem Karabay "); +MODULE_DESCRIPTION("MacBook Pro Touch Bar Backlight driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-appletb-kbd.c b/drivers/hid/hid-appletb-kbd.c new file mode 100644 index 000000000000..d4b95aa3eecb --- /dev/null +++ b/drivers/hid/hid-appletb-kbd.c @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar Keyboard Mode Driver + * + * Copyright (c) 2017-2018 Ronald Tschalär + * Copyright (c) 2022-2023 Kerem Karabay + * Copyright (c) 2024-2025 Aditya Garg + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hid-ids.h" + +#define APPLETB_KBD_MODE_ESC 0 +#define APPLETB_KBD_MODE_FN 1 +#define APPLETB_KBD_MODE_SPCL 2 +#define APPLETB_KBD_MODE_OFF 3 +#define APPLETB_KBD_MODE_MAX APPLETB_KBD_MODE_OFF + +#define APPLETB_DEVID_KEYBOARD 1 +#define APPLETB_DEVID_TRACKPAD 2 + +#define HID_USAGE_MODE 0x00ff0004 + +static int appletb_tb_def_mode = APPLETB_KBD_MODE_SPCL; +module_param_named(mode, appletb_tb_def_mode, int, 0444); +MODULE_PARM_DESC(mode, "Default touchbar mode:\n" + " 0 - escape key only\n" + " 1 - function-keys\n" + " [2] - special keys"); + +static bool appletb_tb_fn_toggle = true; +module_param_named(fntoggle, appletb_tb_fn_toggle, bool, 0644); +MODULE_PARM_DESC(fntoggle, "Switch between Fn and media controls on pressing Fn key"); + +static bool appletb_tb_autodim = true; +module_param_named(autodim, appletb_tb_autodim, bool, 0644); +MODULE_PARM_DESC(autodim, "Automatically dim and turn off the Touch Bar after some time"); + +static int appletb_tb_dim_timeout = 60; +module_param_named(dim_timeout, appletb_tb_dim_timeout, int, 0644); +MODULE_PARM_DESC(dim_timeout, "Dim timeout in sec"); + +static int appletb_tb_idle_timeout = 15; +module_param_named(idle_timeout, appletb_tb_idle_timeout, int, 0644); +MODULE_PARM_DESC(idle_timeout, "Idle timeout in sec"); + +struct appletb_kbd { + struct hid_field *mode_field; + struct input_handler inp_handler; + struct input_handle kbd_handle; + struct input_handle tpd_handle; + struct backlight_device *backlight_dev; + struct timer_list inactivity_timer; + bool has_dimmed; + bool has_turned_off; + u8 saved_mode; + u8 current_mode; +}; + +static const struct key_entry appletb_kbd_keymap[] = { + { KE_KEY, KEY_ESC, { KEY_ESC } }, + { KE_KEY, KEY_F1, { KEY_BRIGHTNESSDOWN } }, + { KE_KEY, KEY_F2, { KEY_BRIGHTNESSUP } }, + { KE_KEY, KEY_F3, { KEY_RESERVED } }, + { KE_KEY, KEY_F4, { KEY_RESERVED } }, + { KE_KEY, KEY_F5, { KEY_KBDILLUMDOWN } }, + { KE_KEY, KEY_F6, { KEY_KBDILLUMUP } }, + { KE_KEY, KEY_F7, { KEY_PREVIOUSSONG } }, + { KE_KEY, KEY_F8, { KEY_PLAYPAUSE } }, + { KE_KEY, KEY_F9, { KEY_NEXTSONG } }, + { KE_KEY, KEY_F10, { KEY_MUTE } }, + { KE_KEY, KEY_F11, { KEY_VOLUMEDOWN } }, + { KE_KEY, KEY_F12, { KEY_VOLUMEUP } }, + { KE_END, 0 } +}; + +static int appletb_kbd_set_mode(struct appletb_kbd *kbd, u8 mode) +{ + struct hid_report *report = kbd->mode_field->report; + struct hid_device *hdev = report->device; + int ret; + + ret = hid_hw_power(hdev, PM_HINT_FULLON); + if (ret) { + hid_err(hdev, "Device didn't resume (%pe)\n", ERR_PTR(ret)); + return ret; + } + + ret = hid_set_field(kbd->mode_field, 0, mode); + if (ret) { + hid_err(hdev, "Failed to set mode field to %u (%pe)\n", mode, ERR_PTR(ret)); + goto power_normal; + } + + hid_hw_request(hdev, report, HID_REQ_SET_REPORT); + + kbd->current_mode = mode; + +power_normal: + hid_hw_power(hdev, PM_HINT_NORMAL); + + return ret; +} + +static ssize_t mode_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct appletb_kbd *kbd = dev_get_drvdata(dev); + + return sysfs_emit(buf, "%d\n", kbd->current_mode); +} + +static ssize_t mode_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t size) +{ + struct appletb_kbd *kbd = dev_get_drvdata(dev); + u8 mode; + int ret; + + ret = kstrtou8(buf, 0, &mode); + if (ret) + return ret; + + if (mode > APPLETB_KBD_MODE_MAX) + return -EINVAL; + + ret = appletb_kbd_set_mode(kbd, mode); + + return ret < 0 ? ret : size; +} +static DEVICE_ATTR_RW(mode); + +static struct attribute *appletb_kbd_attrs[] = { + &dev_attr_mode.attr, + NULL +}; +ATTRIBUTE_GROUPS(appletb_kbd); + +static int appletb_tb_key_to_slot(unsigned int code) +{ + switch (code) { + case KEY_ESC: + return 0; + case KEY_F1 ... KEY_F10: + return code - KEY_F1 + 1; + case KEY_F11 ... KEY_F12: + return code - KEY_F11 + 11; + + default: + return -EINVAL; + } +} + +static void appletb_inactivity_timer(struct timer_list *t) +{ + struct appletb_kbd *kbd = from_timer(kbd, t, inactivity_timer); + + if (kbd->backlight_dev && appletb_tb_autodim) { + if (!kbd->has_dimmed) { + backlight_device_set_brightness(kbd->backlight_dev, 1); + kbd->has_dimmed = true; + mod_timer(&kbd->inactivity_timer, jiffies + msecs_to_jiffies(appletb_tb_idle_timeout * 1000)); + } else if (!kbd->has_turned_off) { + backlight_device_set_brightness(kbd->backlight_dev, 0); + kbd->has_turned_off = true; + } + } +} + +static void reset_inactivity_timer(struct appletb_kbd *kbd) +{ + if (kbd->backlight_dev && appletb_tb_autodim) { + if (kbd->has_dimmed || kbd->has_turned_off) { + backlight_device_set_brightness(kbd->backlight_dev, 2); + kbd->has_dimmed = false; + kbd->has_turned_off = false; + } + mod_timer(&kbd->inactivity_timer, jiffies + msecs_to_jiffies(appletb_tb_dim_timeout * 1000)); + } +} + +static int appletb_kbd_hid_event(struct hid_device *hdev, struct hid_field *field, + struct hid_usage *usage, __s32 value) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + struct key_entry *translation; + struct input_dev *input; + int slot; + + if ((usage->hid & HID_USAGE_PAGE) != HID_UP_KEYBOARD || usage->type != EV_KEY) + return 0; + + input = field->hidinput->input; + + /* + * Skip non-touch-bar keys. + * + * Either the touch bar itself or usbhid generate a slew of key-down + * events for all the meta keys. None of which we're at all interested + * in. + */ + slot = appletb_tb_key_to_slot(usage->code); + if (slot < 0) + return 0; + + reset_inactivity_timer(kbd); + + translation = sparse_keymap_entry_from_scancode(input, usage->code); + + if (translation && kbd->current_mode == APPLETB_KBD_MODE_SPCL) { + input_event(input, usage->type, translation->keycode, value); + + return 1; + } + + return kbd->current_mode == APPLETB_KBD_MODE_OFF; +} + +static void appletb_kbd_inp_event(struct input_handle *handle, unsigned int type, + unsigned int code, int value) +{ + struct appletb_kbd *kbd = handle->private; + + reset_inactivity_timer(kbd); + + if (type == EV_KEY && code == KEY_FN && appletb_tb_fn_toggle && + (kbd->current_mode == APPLETB_KBD_MODE_SPCL || + kbd->current_mode == APPLETB_KBD_MODE_FN)) { + if (value == 1) { + kbd->saved_mode = kbd->current_mode; + appletb_kbd_set_mode(kbd, kbd->current_mode == APPLETB_KBD_MODE_SPCL + ? APPLETB_KBD_MODE_FN : APPLETB_KBD_MODE_SPCL); + } else if (value == 0) { + if (kbd->saved_mode != kbd->current_mode) + appletb_kbd_set_mode(kbd, kbd->saved_mode); + } + } +} + +static int appletb_kbd_inp_connect(struct input_handler *handler, + struct input_dev *dev, + const struct input_device_id *id) +{ + struct appletb_kbd *kbd = handler->private; + struct input_handle *handle; + int rc; + + if (id->driver_info == APPLETB_DEVID_KEYBOARD) { + handle = &kbd->kbd_handle; + handle->name = "tbkbd"; + } else if (id->driver_info == APPLETB_DEVID_TRACKPAD) { + handle = &kbd->tpd_handle; + handle->name = "tbtpd"; + } else { + return -ENOENT; + } + + if (handle->dev) + return -EEXIST; + + handle->open = 0; + handle->dev = input_get_device(dev); + handle->handler = handler; + handle->private = kbd; + + rc = input_register_handle(handle); + if (rc) + goto err_free_dev; + + rc = input_open_device(handle); + if (rc) + goto err_unregister_handle; + + return 0; + + err_unregister_handle: + input_unregister_handle(handle); + err_free_dev: + input_put_device(handle->dev); + handle->dev = NULL; + return rc; +} + +static void appletb_kbd_inp_disconnect(struct input_handle *handle) +{ + input_close_device(handle); + input_unregister_handle(handle); + + input_put_device(handle->dev); + handle->dev = NULL; +} + +static int appletb_kbd_input_configured(struct hid_device *hdev, struct hid_input *hidinput) +{ + int idx; + struct input_dev *input = hidinput->input; + + /* + * Clear various input capabilities that are blindly set by the hid + * driver (usbkbd.c) + */ + memset(input->evbit, 0, sizeof(input->evbit)); + memset(input->keybit, 0, sizeof(input->keybit)); + memset(input->ledbit, 0, sizeof(input->ledbit)); + + __set_bit(EV_REP, input->evbit); + + sparse_keymap_setup(input, appletb_kbd_keymap, NULL); + + for (idx = 0; appletb_kbd_keymap[idx].type != KE_END; idx++) + input_set_capability(input, EV_KEY, appletb_kbd_keymap[idx].code); + + return 0; +} + +static const struct input_device_id appletb_kbd_input_devices[] = { + { + .flags = INPUT_DEVICE_ID_MATCH_BUS | + INPUT_DEVICE_ID_MATCH_VENDOR | + INPUT_DEVICE_ID_MATCH_KEYBIT, + .bustype = BUS_USB, + .vendor = USB_VENDOR_ID_APPLE, + .keybit = { [BIT_WORD(KEY_FN)] = BIT_MASK(KEY_FN) }, + .driver_info = APPLETB_DEVID_KEYBOARD, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_BUS | + INPUT_DEVICE_ID_MATCH_VENDOR | + INPUT_DEVICE_ID_MATCH_KEYBIT, + .bustype = BUS_USB, + .vendor = USB_VENDOR_ID_APPLE, + .keybit = { [BIT_WORD(BTN_TOUCH)] = BIT_MASK(BTN_TOUCH) }, + .driver_info = APPLETB_DEVID_TRACKPAD, + }, + { } +}; + +static bool appletb_kbd_match_internal_device(struct input_handler *handler, + struct input_dev *inp_dev) +{ + struct device *dev = &inp_dev->dev; + + /* in kernel: dev && !is_usb_device(dev) */ + while (dev && !(dev->type && dev->type->name && + !strcmp(dev->type->name, "usb_device"))) + dev = dev->parent; + + /* + * Apple labels all their internal keyboards and trackpads as such, + * instead of maintaining an ever expanding list of product-id's we + * just look at the device's product name. + */ + if (dev) + return !!strstr(to_usb_device(dev)->product, "Internal Keyboard"); + + return false; +} + +static int appletb_kbd_probe(struct hid_device *hdev, const struct hid_device_id *id) +{ + struct appletb_kbd *kbd; + struct device *dev = &hdev->dev; + struct hid_field *mode_field; + int ret; + + ret = hid_parse(hdev); + if (ret) + return dev_err_probe(dev, ret, "HID parse failed\n"); + + mode_field = hid_find_field(hdev, HID_OUTPUT_REPORT, + HID_GD_KEYBOARD, HID_USAGE_MODE); + if (!mode_field) + return -ENODEV; + + kbd = devm_kzalloc(dev, sizeof(*kbd), GFP_KERNEL); + if (!kbd) + return -ENOMEM; + + kbd->mode_field = mode_field; + + ret = hid_hw_start(hdev, HID_CONNECT_HIDINPUT); + if (ret) + return dev_err_probe(dev, ret, "HID hw start failed\n"); + + ret = hid_hw_open(hdev); + if (ret) { + dev_err_probe(dev, ret, "HID hw open failed\n"); + goto stop_hw; + } + + kbd->backlight_dev = backlight_device_get_by_name("appletb_backlight"); + if (!kbd->backlight_dev) { + dev_err_probe(dev, -ENODEV, "Failed to get backlight device\n"); + } else { + backlight_device_set_brightness(kbd->backlight_dev, 2); + timer_setup(&kbd->inactivity_timer, appletb_inactivity_timer, 0); + mod_timer(&kbd->inactivity_timer, jiffies + msecs_to_jiffies(appletb_tb_dim_timeout * 1000)); + } + + kbd->inp_handler.event = appletb_kbd_inp_event; + kbd->inp_handler.connect = appletb_kbd_inp_connect; + kbd->inp_handler.disconnect = appletb_kbd_inp_disconnect; + kbd->inp_handler.name = "appletb"; + kbd->inp_handler.id_table = appletb_kbd_input_devices; + kbd->inp_handler.match = appletb_kbd_match_internal_device; + kbd->inp_handler.private = kbd; + + ret = input_register_handler(&kbd->inp_handler); + if (ret) { + dev_err_probe(dev, ret, "Unable to register keyboard handler\n"); + goto close_hw; + } + + ret = appletb_kbd_set_mode(kbd, appletb_tb_def_mode); + if (ret) { + dev_err_probe(dev, ret, "Failed to set touchbar mode\n"); + goto close_hw; + } + + hid_set_drvdata(hdev, kbd); + + return 0; + +close_hw: + hid_hw_close(hdev); +stop_hw: + hid_hw_stop(hdev); + return ret; +} + +static void appletb_kbd_remove(struct hid_device *hdev) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); + + input_unregister_handler(&kbd->inp_handler); + del_timer_sync(&kbd->inactivity_timer); + + hid_hw_close(hdev); + hid_hw_stop(hdev); +} + +#ifdef CONFIG_PM +static int appletb_kbd_suspend(struct hid_device *hdev, pm_message_t msg) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + kbd->saved_mode = kbd->current_mode; + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); + + return 0; +} + +static int appletb_kbd_reset_resume(struct hid_device *hdev) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + appletb_kbd_set_mode(kbd, kbd->saved_mode); + + return 0; +} +#endif + +static const struct hid_device_id appletb_kbd_hid_ids[] = { + /* MacBook Pro's 2018, 2019, with T2 chip: iBridge Display */ + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, + { } +}; +MODULE_DEVICE_TABLE(hid, appletb_kbd_hid_ids); + +static struct hid_driver appletb_kbd_hid_driver = { + .name = "hid-appletb-kbd", + .id_table = appletb_kbd_hid_ids, + .probe = appletb_kbd_probe, + .remove = appletb_kbd_remove, + .event = appletb_kbd_hid_event, + .input_configured = appletb_kbd_input_configured, +#ifdef CONFIG_PM + .suspend = appletb_kbd_suspend, + .reset_resume = appletb_kbd_reset_resume, +#endif + .driver.dev_groups = appletb_kbd_groups, +}; +module_hid_driver(appletb_kbd_hid_driver); + +/* The backlight driver should be loaded before the keyboard driver is initialised */ +MODULE_SOFTDEP("pre: hid_appletb_bl"); + +MODULE_AUTHOR("Ronald Tschalär"); +MODULE_AUTHOR("Kerem Karabay "); +MODULE_AUTHOR("Aditya Garg "); +MODULE_DESCRIPTION("MacBook Pro Touch Bar Keyboard Mode driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-core.c b/drivers/hid/hid-core.c index 4497b50799db..a79fd45c7a2c 100644 --- a/drivers/hid/hid-core.c +++ b/drivers/hid/hid-core.c @@ -464,7 +464,10 @@ static int hid_parser_global(struct hid_parser *parser, struct hid_item *item) case HID_GLOBAL_ITEM_TAG_REPORT_SIZE: parser->global.report_size = item_udata(item); - if (parser->global.report_size > 256) { + /* Arbitrary maximum. Some Apple devices have 16384 here. + * This * HID_MAX_USAGES must fit in a signed integer. + */ + if (parser->global.report_size > 16384) { hid_err(parser->device, "invalid report_size %d\n", parser->global.report_size); return -1; @@ -2290,6 +2293,12 @@ int hid_connect(struct hid_device *hdev, unsigned int connect_mask) case BUS_I2C: bus = "I2C"; break; + case BUS_SPI: + bus = "SPI"; + break; + case BUS_HOST: + bus = "HOST"; + break; case BUS_VIRTUAL: bus = "VIRTUAL"; break; diff --git a/drivers/hid/hid-ids.h b/drivers/hid/hid-ids.h index 50cd02b049fc..b716cafc63b1 100644 --- a/drivers/hid/hid-ids.h +++ b/drivers/hid/hid-ids.h @@ -89,6 +89,8 @@ #define USB_VENDOR_ID_APPLE 0x05ac #define BT_VENDOR_ID_APPLE 0x004c +#define SPI_VENDOR_ID_APPLE 0x05ac +#define HOST_VENDOR_ID_APPLE 0x05ac #define USB_DEVICE_ID_APPLE_MIGHTYMOUSE 0x0304 #define USB_DEVICE_ID_APPLE_MAGICMOUSE 0x030d #define USB_DEVICE_ID_APPLE_MAGICMOUSE2 0x0269 @@ -168,14 +170,15 @@ #define USB_DEVICE_ID_APPLE_WELLSPRING9_ANSI 0x0272 #define USB_DEVICE_ID_APPLE_WELLSPRING9_ISO 0x0273 #define USB_DEVICE_ID_APPLE_WELLSPRING9_JIS 0x0274 -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K 0x027a -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132 0x027b -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680 0x027c -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213 0x027d -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K 0x027e -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223 0x027f -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K 0x0280 -#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F 0x0340 +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K 0x027a +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132 0x027b +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680 0x027c +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT 0x0278 +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213 0x027d +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K 0x027e +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223 0x027f +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K 0x0280 +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F 0x0340 #define USB_DEVICE_ID_APPLE_FOUNTAIN_TP_ONLY 0x030a #define USB_DEVICE_ID_APPLE_GEYSER1_TP_ONLY 0x030b #define USB_DEVICE_ID_APPLE_IRCONTROL 0x8240 @@ -189,6 +192,12 @@ #define USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_NUMPAD_2021 0x029f #define USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT 0x8102 #define USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY 0x8302 +#define SPI_DEVICE_ID_APPLE_MACBOOK_AIR_2020 0x0281 +#define SPI_DEVICE_ID_APPLE_MACBOOK_PRO13_2020 0x0341 +#define SPI_DEVICE_ID_APPLE_MACBOOK_PRO14_2021 0x0342 +#define SPI_DEVICE_ID_APPLE_MACBOOK_PRO16_2021 0x0343 +#define HOST_DEVICE_ID_APPLE_MACBOOK_AIR13_2022 0x0351 +#define HOST_DEVICE_ID_APPLE_MACBOOK_PRO13_2022 0x0354 #define USB_VENDOR_ID_ASETEK 0x2433 #define USB_DEVICE_ID_ASETEK_INVICTA 0xf300 diff --git a/drivers/hid/hid-magicmouse.c b/drivers/hid/hid-magicmouse.c index a76f17158539..4598cbac49db 100644 --- a/drivers/hid/hid-magicmouse.c +++ b/drivers/hid/hid-magicmouse.c @@ -60,8 +60,14 @@ MODULE_PARM_DESC(report_undeciphered, "Report undeciphered multi-touch state fie #define MOUSE_REPORT_ID 0x29 #define MOUSE2_REPORT_ID 0x12 #define DOUBLE_REPORT_ID 0xf7 +#define SPI_REPORT_ID 0x02 +#define SPI_RESET_REPORT_ID 0x60 +#define MTP_REPORT_ID 0x75 +#define SENSOR_DIMENSIONS_REPORT_ID 0xd9 #define USB_BATTERY_TIMEOUT_MS 60000 +#define MAX_CONTACTS 16 + /* These definitions are not precise, but they're close enough. (Bits * 0x03 seem to indicate the aspect ratio of the touch, bits 0x70 seem * to be some kind of bit mask -- 0x20 may be a near-field reading, @@ -112,30 +118,156 @@ MODULE_PARM_DESC(report_undeciphered, "Report undeciphered multi-touch state fie #define TRACKPAD2_RES_Y \ ((TRACKPAD2_MAX_Y - TRACKPAD2_MIN_Y) / (TRACKPAD2_DIMENSION_Y / 100)) +#define J140K_TP_DIMENSION_X (float)12100 +#define J140K_TP_MIN_X -5318 +#define J140K_TP_MAX_X 5787 +#define J140K_TP_RES_X \ + ((J140K_TP_MAX_X - J140K_TP_MIN_X) / (J140K_TP_DIMENSION_X / 100)) +#define J140K_TP_DIMENSION_Y (float)8200 +#define J140K_TP_MIN_Y -157 +#define J140K_TP_MAX_Y 7102 +#define J140K_TP_RES_Y \ + ((J140K_TP_MAX_Y - J140K_TP_MIN_Y) / (J140K_TP_DIMENSION_Y / 100)) + +#define J132_TP_DIMENSION_X (float)13500 +#define J132_TP_MIN_X -6243 +#define J132_TP_MAX_X 6749 +#define J132_TP_RES_X \ + ((J132_TP_MAX_X - J132_TP_MIN_X) / (J132_TP_DIMENSION_X / 100)) +#define J132_TP_DIMENSION_Y (float)8400 +#define J132_TP_MIN_Y -170 +#define J132_TP_MAX_Y 7685 +#define J132_TP_RES_Y \ + ((J132_TP_MAX_Y - J132_TP_MIN_Y) / (J132_TP_DIMENSION_Y / 100)) + +#define J680_TP_DIMENSION_X (float)16000 +#define J680_TP_MIN_X -7456 +#define J680_TP_MAX_X 7976 +#define J680_TP_RES_X \ + ((J680_TP_MAX_X - J680_TP_MIN_X) / (J680_TP_DIMENSION_X / 100)) +#define J680_TP_DIMENSION_Y (float)10000 +#define J680_TP_MIN_Y -163 +#define J680_TP_MAX_Y 9283 +#define J680_TP_RES_Y \ + ((J680_TP_MAX_Y - J680_TP_MIN_Y) / (J680_TP_DIMENSION_Y / 100)) + +#define J680_ALT_TP_DIMENSION_X (float)16000 +#define J680_ALT_TP_MIN_X -7456 +#define J680_ALT_TP_MAX_X 7976 +#define J680_ALT_TP_RES_X \ + ((J680_ALT_TP_MAX_X - J680_ALT_TP_MIN_X) / (J680_ALT_TP_DIMENSION_X / 100)) +#define J680_ALT_TP_DIMENSION_Y (float)10000 +#define J680_ALT_TP_MIN_Y -163 +#define J680_ALT_TP_MAX_Y 9283 +#define J680_ALT_TP_RES_Y \ + ((J680_ALT_TP_MAX_Y - J680_ALT_TP_MIN_Y) / (J680_ALT_TP_DIMENSION_Y / 100)) + +#define J213_TP_DIMENSION_X (float)13500 +#define J213_TP_MIN_X -6243 +#define J213_TP_MAX_X 6749 +#define J213_TP_RES_X \ + ((J213_TP_MAX_X - J213_TP_MIN_X) / (J213_TP_DIMENSION_X / 100)) +#define J213_TP_DIMENSION_Y (float)8400 +#define J213_TP_MIN_Y -170 +#define J213_TP_MAX_Y 7685 +#define J213_TP_RES_Y \ + ((J213_TP_MAX_Y - J213_TP_MIN_Y) / (J213_TP_DIMENSION_Y / 100)) + +#define J214K_TP_DIMENSION_X (float)13200 +#define J214K_TP_MIN_X -6046 +#define J214K_TP_MAX_X 6536 +#define J214K_TP_RES_X \ + ((J214K_TP_MAX_X - J214K_TP_MIN_X) / (J214K_TP_DIMENSION_X / 100)) +#define J214K_TP_DIMENSION_Y (float)8200 +#define J214K_TP_MIN_Y -164 +#define J214K_TP_MAX_Y 7439 +#define J214K_TP_RES_Y \ + ((J214K_TP_MAX_Y - J214K_TP_MIN_Y) / (J214K_TP_DIMENSION_Y / 100)) + +#define J223_TP_DIMENSION_X (float)13200 +#define J223_TP_MIN_X -6046 +#define J223_TP_MAX_X 6536 +#define J223_TP_RES_X \ + ((J223_TP_MAX_X - J223_TP_MIN_X) / (J223_TP_DIMENSION_X / 100)) +#define J223_TP_DIMENSION_Y (float)8200 +#define J223_TP_MIN_Y -164 +#define J223_TP_MAX_Y 7439 +#define J223_TP_RES_Y \ + ((J223_TP_MAX_Y - J223_TP_MIN_Y) / (J223_TP_DIMENSION_Y / 100)) + +#define J230K_TP_DIMENSION_X (float)12100 +#define J230K_TP_MIN_X -5318 +#define J230K_TP_MAX_X 5787 +#define J230K_TP_RES_X \ + ((J230K_TP_MAX_X - J230K_TP_MIN_X) / (J230K_TP_DIMENSION_X / 100)) +#define J230K_TP_DIMENSION_Y (float)8200 +#define J230K_TP_MIN_Y -157 +#define J230K_TP_MAX_Y 7102 +#define J230K_TP_RES_Y \ + ((J230K_TP_MAX_Y - J230K_TP_MIN_Y) / (J230K_TP_DIMENSION_Y / 100)) + +#define J152F_TP_DIMENSION_X (float)16000 +#define J152F_TP_MIN_X -7456 +#define J152F_TP_MAX_X 7976 +#define J152F_TP_RES_X \ + ((J152F_TP_MAX_X - J152F_TP_MIN_X) / (J152F_TP_DIMENSION_X / 100)) +#define J152F_TP_DIMENSION_Y (float)10000 +#define J152F_TP_MIN_Y -163 +#define J152F_TP_MAX_Y 9283 +#define J152F_TP_RES_Y \ + ((J152F_TP_MAX_Y - J152F_TP_MIN_Y) / (J152F_TP_DIMENSION_Y / 100)) + +/* These are fallback values, since the real values will be queried from the device. */ +#define J314_TP_DIMENSION_X (float)13000 +#define J314_TP_MIN_X -5900 +#define J314_TP_MAX_X 6500 +#define J314_TP_RES_X \ + ((J314_TP_MAX_X - J314_TP_MIN_X) / (J314_TP_DIMENSION_X / 100)) +#define J314_TP_DIMENSION_Y (float)8100 +#define J314_TP_MIN_Y -200 +#define J314_TP_MAX_Y 7400 +#define J314_TP_RES_Y \ + ((J314_TP_MAX_Y - J314_TP_MIN_Y) / (J314_TP_DIMENSION_Y / 100)) + +#define T2_TOUCHPAD_ENTRY(model) \ + { USB_DEVICE_ID_APPLE_WELLSPRINGT2_##model, model##_TP_MIN_X, model##_TP_MIN_Y, \ +model##_TP_MAX_X, model##_TP_MAX_Y, model##_TP_RES_X, model##_TP_RES_Y } + +#define INTERNAL_TP_MAX_FINGER_ORIENTATION 16384 + +struct magicmouse_input_ops { + int (*raw_event)(struct hid_device *hdev, + struct hid_report *report, u8 *data, int size); + int (*setup_input)(struct input_dev *input, struct hid_device *hdev); +}; + /** * struct magicmouse_sc - Tracks Magic Mouse-specific data. * @input: Input device through which we report events. * @quirks: Currently unused. + * @query_dimensions: Whether to query and update dimensions on first open * @ntouches: Number of touches in most recent touch report. * @scroll_accel: Number of consecutive scroll motions. * @scroll_jiffies: Time of last scroll motion. + * @pos: multi touch position data of the last report. * @touches: Most recent data for a touch, indexed by tracking ID. * @tracking_ids: Mapping of current touch input data to @touches. * @hdev: Pointer to the underlying HID device. * @work: Workqueue to handle initialization retry for quirky devices. * @battery_timer: Timer for obtaining battery level information. + * @input_ops: Input ops based on device type. */ struct magicmouse_sc { struct input_dev *input; unsigned long quirks; + bool query_dimensions; int ntouches; int scroll_accel; unsigned long scroll_jiffies; + struct input_mt_pos pos[MAX_CONTACTS]; struct { - short x; - short y; short scroll_x; short scroll_y; short scroll_x_hr; @@ -143,14 +275,186 @@ struct magicmouse_sc { u8 size; bool scroll_x_active; bool scroll_y_active; - } touches[16]; - int tracking_ids[16]; + } touches[MAX_CONTACTS]; + int tracking_ids[MAX_CONTACTS]; struct hid_device *hdev; struct delayed_work work; struct timer_list battery_timer; + struct magicmouse_input_ops input_ops; }; +static inline int le16_to_int(__le16 x) +{ + return (signed short)le16_to_cpu(x); +} + +static int magicmouse_enable_multitouch(struct hid_device *hdev) +{ + const u8 *feature; + const u8 feature_mt[] = { 0xD7, 0x01 }; + const u8 feature_mt_mouse2[] = { 0xF1, 0x02, 0x01 }; + const u8 feature_mt_trackpad2_usb[] = { 0x02, 0x01 }; + const u8 feature_mt_trackpad2_bt[] = { 0xF1, 0x02, 0x01 }; + u8 *buf; + int ret; + int feature_size; + + switch (hdev->bus) { + case BUS_SPI: + case BUS_HOST: + feature_size = sizeof(feature_mt_trackpad2_usb); + feature = feature_mt_trackpad2_usb; + break; + default: + switch (hdev->product) { + case USB_DEVICE_ID_APPLE_MAGICTRACKPAD2: + case USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC: + switch (hdev->vendor) { + case BT_VENDOR_ID_APPLE: + feature_size = sizeof(feature_mt_trackpad2_bt); + feature = feature_mt_trackpad2_bt; + break; + default: /* USB_VENDOR_ID_APPLE */ + feature_size = sizeof(feature_mt_trackpad2_usb); + feature = feature_mt_trackpad2_usb; + } + break; + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F: + feature_size = sizeof(feature_mt_trackpad2_usb); + feature = feature_mt_trackpad2_usb; + break; + case USB_DEVICE_ID_APPLE_MAGICMOUSE2: + feature_size = sizeof(feature_mt_mouse2); + feature = feature_mt_mouse2; + break; + default: + feature_size = sizeof(feature_mt); + feature = feature_mt; + } + } + + buf = kmemdup(feature, feature_size, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + ret = hid_hw_raw_request(hdev, buf[0], buf, feature_size, + HID_FEATURE_REPORT, HID_REQ_SET_REPORT); + kfree(buf); + return ret; +} + +static void magicmouse_enable_mt_work(struct work_struct *work) +{ + struct magicmouse_sc *msc = + container_of(work, struct magicmouse_sc, work.work); + int ret; + + ret = magicmouse_enable_multitouch(msc->hdev); + if (ret < 0) + hid_err(msc->hdev, "unable to request touch data (%d)\n", ret); +} + +static int magicmouse_open(struct input_dev *dev) +{ + struct hid_device *hdev = input_get_drvdata(dev); + struct magicmouse_sc *msc = hid_get_drvdata(hdev); + int ret; + + ret = hid_hw_open(hdev); + if (ret) + return ret; + + /* + * Some devices repond with 'invalid report id' when feature + * report switching it into multitouch mode is sent to it. + * + * This results in -EIO from the _raw low-level transport callback, + * but there seems to be no other way of switching the mode. + * Thus the super-ugly hacky success check below. + * + * MTP devices do not need this. + */ + if (hdev->bus != BUS_HOST) { + ret = magicmouse_enable_multitouch(hdev); + if (ret == -EIO && hdev->product == USB_DEVICE_ID_APPLE_MAGICMOUSE2) { + schedule_delayed_work(&msc->work, msecs_to_jiffies(500)); + return 0; + } + if (ret < 0) + hid_err(hdev, "unable to request touch data (%d)\n", ret); + } + /* + * MT enable is usually not required after the first time, so don't + * consider it fatal. + */ + + /* + * For Apple Silicon trackpads, we want to query the dimensions on + * device open. This is because doing so requires the firmware, but + * we don't want to force a firmware load until the device is opened + * for the first time. So do that here and update the input properties + * just in time before userspace queries them. + */ + if (msc->query_dimensions) { + struct input_dev *input = msc->input; + u8 buf[32]; + struct { + __le32 width; + __le32 height; + __le16 min_x; + __le16 min_y; + __le16 max_x; + __le16 max_y; + } dim; + uint32_t x_span, y_span; + + ret = hid_hw_raw_request(hdev, SENSOR_DIMENSIONS_REPORT_ID, buf, sizeof(buf), HID_FEATURE_REPORT, HID_REQ_GET_REPORT); + if (ret < (int)(1 + sizeof(dim))) { + hid_err(hdev, "unable to request dimensions (%d)\n", ret); + return ret; + } + + memcpy(&dim, buf + 1, sizeof(dim)); + + /* finger position */ + input_set_abs_params(input, ABS_MT_POSITION_X, + le16_to_int(dim.min_x), le16_to_int(dim.max_x), 0, 0); + /* Y axis is inverted */ + input_set_abs_params(input, ABS_MT_POSITION_Y, + -le16_to_int(dim.max_y), -le16_to_int(dim.min_y), 0, 0); + x_span = le16_to_int(dim.max_x) - le16_to_int(dim.min_x); + y_span = le16_to_int(dim.max_y) - le16_to_int(dim.min_y); + + /* X/Y resolution */ + input_abs_set_res(input, ABS_MT_POSITION_X, 100 * x_span / le32_to_cpu(dim.width) ); + input_abs_set_res(input, ABS_MT_POSITION_Y, 100 * y_span / le32_to_cpu(dim.height) ); + + /* copy info, as input_mt_init_slots() does */ + dev->absinfo[ABS_X] = dev->absinfo[ABS_MT_POSITION_X]; + dev->absinfo[ABS_Y] = dev->absinfo[ABS_MT_POSITION_Y]; + + msc->query_dimensions = false; + } + + return 0; +} + +static void magicmouse_close(struct input_dev *dev) +{ + struct hid_device *hdev = input_get_drvdata(dev); + + hid_hw_close(hdev); +} + static int magicmouse_firm_touch(struct magicmouse_sc *msc) { int touch = -1; @@ -192,7 +496,7 @@ static void magicmouse_emit_buttons(struct magicmouse_sc *msc, int state) } else if (last_state != 0) { state = last_state; } else if ((id = magicmouse_firm_touch(msc)) >= 0) { - int x = msc->touches[id].x; + int x = msc->pos[id].x; if (x < middle_button_start) state = 1; else if (x > middle_button_stop) @@ -255,8 +559,8 @@ static void magicmouse_emit_touch(struct magicmouse_sc *msc, int raw_id, u8 *tda /* Store tracking ID and other fields. */ msc->tracking_ids[raw_id] = id; - msc->touches[id].x = x; - msc->touches[id].y = y; + msc->pos[id].x = x; + msc->pos[id].y = y; msc->touches[id].size = size; /* If requested, emulate a scroll wheel by detecting small @@ -385,6 +689,14 @@ static int magicmouse_raw_event(struct hid_device *hdev, struct hid_report *report, u8 *data, int size) { struct magicmouse_sc *msc = hid_get_drvdata(hdev); + + return msc->input_ops.raw_event(hdev, report, data, size); +} + +static int magicmouse_raw_event_usb(struct hid_device *hdev, + struct hid_report *report, u8 *data, int size) +{ + struct magicmouse_sc *msc = hid_get_drvdata(hdev); struct input_dev *input = msc->input; int x = 0, y = 0, ii, clicks = 0, npoints; @@ -515,6 +827,191 @@ static int magicmouse_raw_event(struct hid_device *hdev, return 1; } +/** + * struct tp_finger - single trackpad finger structure, le16-aligned + * + * @unknown1: unknown + * @unknown2: unknown + * @abs_x: absolute x coordinate + * @abs_y: absolute y coordinate + * @rel_x: relative x coordinate + * @rel_y: relative y coordinate + * @tool_major: tool area, major axis + * @tool_minor: tool area, minor axis + * @orientation: 16384 when point, else 15 bit angle + * @touch_major: touch area, major axis + * @touch_minor: touch area, minor axis + * @unused: zeros + * @pressure: pressure on forcetouch touchpad + * @multi: one finger: varies, more fingers: constant + * @crc16: on last finger: crc over the whole message struct + * (i.e. message header + this struct) minus the last + * @crc16 field; unknown on all other fingers. + */ +struct tp_finger { + __le16 unknown1; + __le16 unknown2; + __le16 abs_x; + __le16 abs_y; + __le16 rel_x; + __le16 rel_y; + __le16 tool_major; + __le16 tool_minor; + __le16 orientation; + __le16 touch_major; + __le16 touch_minor; + __le16 unused[2]; + __le16 pressure; + __le16 multi; +} __attribute__((packed, aligned(2))); + +/** + * vendor trackpad report + * + * @num_fingers: the number of fingers being reported in @fingers + * @buttons: same as HID buttons + */ +struct tp_header { + // HID vendor part, up to 1751 bytes + u8 unknown[22]; + u8 num_fingers; + u8 buttons; + u8 unknown3[14]; +}; + +/** + * standard HID mouse report + * + * @report_id: reportid + * @buttons: HID Usage Buttons 3 1-bit reports + */ +struct tp_mouse_report { + // HID mouse report + u8 report_id; + u8 buttons; + u8 rel_x; + u8 rel_y; + u8 padding[4]; +}; + +static void report_finger_data(struct input_dev *input, int slot, + const struct input_mt_pos *pos, + const struct tp_finger *f) +{ + input_mt_slot(input, slot); + input_mt_report_slot_state(input, MT_TOOL_FINGER, true); + + input_report_abs(input, ABS_MT_TOUCH_MAJOR, + le16_to_int(f->touch_major) << 1); + input_report_abs(input, ABS_MT_TOUCH_MINOR, + le16_to_int(f->touch_minor) << 1); + input_report_abs(input, ABS_MT_WIDTH_MAJOR, + le16_to_int(f->tool_major) << 1); + input_report_abs(input, ABS_MT_WIDTH_MINOR, + le16_to_int(f->tool_minor) << 1); + input_report_abs(input, ABS_MT_ORIENTATION, + INTERNAL_TP_MAX_FINGER_ORIENTATION - le16_to_int(f->orientation)); + input_report_abs(input, ABS_MT_PRESSURE, le16_to_int(f->pressure)); + input_report_abs(input, ABS_MT_POSITION_X, pos->x); + input_report_abs(input, ABS_MT_POSITION_Y, pos->y); +} + +static int magicmouse_raw_event_mtp(struct hid_device *hdev, + struct hid_report *report, u8 *data, int size) +{ + struct magicmouse_sc *msc = hid_get_drvdata(hdev); + struct input_dev *input = msc->input; + struct tp_header *tp_hdr; + struct tp_finger *f; + int i, n; + u32 npoints; + const size_t hdr_sz = sizeof(struct tp_header); + const size_t touch_sz = sizeof(struct tp_finger); + u8 map_contacs[MAX_CONTACTS]; + + // hid_warn(hdev, "%s\n", __func__); + // print_hex_dump_debug("appleft ev: ", DUMP_PREFIX_OFFSET, 16, 1, data, + // size, false); + + /* Expect 46 bytes of prefix, and N * 30 bytes of touch data. */ + if (size < hdr_sz || ((size - hdr_sz) % touch_sz) != 0) + return 0; + + tp_hdr = (struct tp_header *)data; + + npoints = (size - hdr_sz) / touch_sz; + if (npoints < tp_hdr->num_fingers || npoints > MAX_CONTACTS) { + hid_warn(hdev, + "unexpected number of touches (%u) for " + "report\n", + npoints); + return 0; + } + + n = 0; + for (i = 0; i < tp_hdr->num_fingers; i++) { + f = (struct tp_finger *)(data + hdr_sz + i * touch_sz); + if (le16_to_int(f->touch_major) == 0) + continue; + + hid_dbg(hdev, "ev x:%04x y:%04x\n", le16_to_int(f->abs_x), + le16_to_int(f->abs_y)); + msc->pos[n].x = le16_to_int(f->abs_x); + msc->pos[n].y = -le16_to_int(f->abs_y); + map_contacs[n] = i; + n++; + } + + input_mt_assign_slots(input, msc->tracking_ids, msc->pos, n, 0); + + for (i = 0; i < n; i++) { + int idx = map_contacs[i]; + f = (struct tp_finger *)(data + hdr_sz + idx * touch_sz); + report_finger_data(input, msc->tracking_ids[i], &msc->pos[i], f); + } + + input_mt_sync_frame(input); + input_report_key(input, BTN_MOUSE, tp_hdr->buttons & 1); + + input_sync(input); + return 1; +} + +static int magicmouse_raw_event_spi(struct hid_device *hdev, + struct hid_report *report, u8 *data, int size) +{ + struct magicmouse_sc *msc = hid_get_drvdata(hdev); + const size_t hdr_sz = sizeof(struct tp_mouse_report); + + if (!size) + return 0; + + if (data[0] == SPI_RESET_REPORT_ID) { + hid_info(hdev, "Touch controller was reset, re-enabling touch mode\n"); + schedule_delayed_work(&msc->work, msecs_to_jiffies(10)); + return 1; + } + + if (data[0] != TRACKPAD2_USB_REPORT_ID || size < hdr_sz) + return 0; + + return magicmouse_raw_event_mtp(hdev, report, data + hdr_sz, size - hdr_sz); +} + +static int magicmouse_raw_event_t2(struct hid_device *hdev, + struct hid_report *report, u8 *data, int size) +{ + const size_t hdr_sz = sizeof(struct tp_mouse_report); + + if (!size) + return 0; + + if (data[0] != TRACKPAD2_USB_REPORT_ID || size < hdr_sz) + return 0; + + return magicmouse_raw_event_mtp(hdev, report, data + hdr_sz, size - hdr_sz); +} + static int magicmouse_event(struct hid_device *hdev, struct hid_field *field, struct hid_usage *usage, __s32 value) { @@ -532,7 +1029,17 @@ static int magicmouse_event(struct hid_device *hdev, struct hid_field *field, return 0; } -static int magicmouse_setup_input(struct input_dev *input, struct hid_device *hdev) + +static int magicmouse_setup_input(struct input_dev *input, + struct hid_device *hdev) +{ + struct magicmouse_sc *msc = hid_get_drvdata(hdev); + + return msc->input_ops.setup_input(input, hdev); +} + +static int magicmouse_setup_input_usb(struct input_dev *input, + struct hid_device *hdev) { int error; int mt_flags = 0; @@ -610,7 +1117,7 @@ static int magicmouse_setup_input(struct input_dev *input, struct hid_device *hd __set_bit(EV_ABS, input->evbit); - error = input_mt_init_slots(input, 16, mt_flags); + error = input_mt_init_slots(input, MAX_CONTACTS, mt_flags); if (error) return error; input_set_abs_params(input, ABS_MT_TOUCH_MAJOR, 0, 255 << 2, @@ -689,6 +1196,171 @@ static int magicmouse_setup_input(struct input_dev *input, struct hid_device *hd */ __clear_bit(EV_REP, input->evbit); + /* + * This isn't strictly speaking needed for USB, but enabling MT on + * device open is probably more robust than only doing it once on probe + * even if USB devices are not known to suffer from the SPI reset issue. + */ + input->open = magicmouse_open; + input->close = magicmouse_close; + return 0; +} + +struct magicmouse_t2_properties { + u32 id; + int min_x; + int min_y; + int max_x; + int max_y; + int res_x; + int res_y; +}; + +static const struct magicmouse_t2_properties magicmouse_t2_configs[] = { + T2_TOUCHPAD_ENTRY(J140K), + T2_TOUCHPAD_ENTRY(J132), + T2_TOUCHPAD_ENTRY(J680), + T2_TOUCHPAD_ENTRY(J680_ALT), + T2_TOUCHPAD_ENTRY(J213), + T2_TOUCHPAD_ENTRY(J214K), + T2_TOUCHPAD_ENTRY(J223), + T2_TOUCHPAD_ENTRY(J230K), + T2_TOUCHPAD_ENTRY(J152F), +}; + +static int magicmouse_setup_input_int_tpd(struct input_dev *input, + struct hid_device *hdev, int min_x, int min_y, + int max_x, int max_y, int res_x, int res_y, + bool query_dimensions) +{ + int error; + int mt_flags = 0; + struct magicmouse_sc *msc = hid_get_drvdata(hdev); + + __set_bit(INPUT_PROP_BUTTONPAD, input->propbit); + __clear_bit(BTN_0, input->keybit); + __clear_bit(BTN_RIGHT, input->keybit); + __clear_bit(BTN_MIDDLE, input->keybit); + __clear_bit(EV_REL, input->evbit); + __clear_bit(REL_X, input->relbit); + __clear_bit(REL_Y, input->relbit); + + mt_flags = INPUT_MT_POINTER | INPUT_MT_DROP_UNUSED | INPUT_MT_TRACK; + + /* finger touch area */ + input_set_abs_params(input, ABS_MT_TOUCH_MAJOR, 0, 5000, 0, 0); + input_set_abs_params(input, ABS_MT_TOUCH_MINOR, 0, 5000, 0, 0); + + /* finger approach area */ + input_set_abs_params(input, ABS_MT_WIDTH_MAJOR, 0, 5000, 0, 0); + input_set_abs_params(input, ABS_MT_WIDTH_MINOR, 0, 5000, 0, 0); + + /* Note: Touch Y position from the device is inverted relative + * to how pointer motion is reported (and relative to how USB + * HID recommends the coordinates work). This driver keeps + * the origin at the same position, and just uses the additive + * inverse of the reported Y. + */ + + input_set_abs_params(input, ABS_MT_PRESSURE, 0, 6000, 0, 0); + + /* + * This makes libinput recognize this as a PressurePad and + * stop trying to use pressure for touch size. Pressure unit + * seems to be ~grams on these touchpads. + */ + input_abs_set_res(input, ABS_MT_PRESSURE, 1); + + /* finger orientation */ + input_set_abs_params(input, ABS_MT_ORIENTATION, -INTERNAL_TP_MAX_FINGER_ORIENTATION, + INTERNAL_TP_MAX_FINGER_ORIENTATION, 0, 0); + + /* finger position */ + input_set_abs_params(input, ABS_MT_POSITION_X, min_x, max_x, 0, 0); + /* Y axis is inverted */ + input_set_abs_params(input, ABS_MT_POSITION_Y, -max_y, -min_y, 0, 0); + + /* X/Y resolution */ + input_abs_set_res(input, ABS_MT_POSITION_X, res_x); + input_abs_set_res(input, ABS_MT_POSITION_Y, res_y); + + input_set_events_per_packet(input, 60); + + /* touchpad button */ + input_set_capability(input, EV_KEY, BTN_MOUSE); + + /* + * hid-input may mark device as using autorepeat, but the trackpad does + * not actually want it. + */ + __clear_bit(EV_REP, input->evbit); + + error = input_mt_init_slots(input, MAX_CONTACTS, mt_flags); + if (error) + return error; + + /* + * Override the default input->open function to send the MT + * enable every time the device is opened. This ensures it works + * even if we missed a reset event due to the device being closed. + * input->close is overridden for symmetry. + * + * This also takes care of the dimensions query. + */ + input->open = magicmouse_open; + input->close = magicmouse_close; + msc->query_dimensions = query_dimensions; + + return 0; +} + +static int magicmouse_setup_input_mtp(struct input_dev *input, + struct hid_device *hdev) +{ + int ret = magicmouse_setup_input_int_tpd(input, hdev, J314_TP_MIN_X, + J314_TP_MIN_Y, J314_TP_MAX_X, + J314_TP_MAX_Y, J314_TP_RES_X, + J314_TP_RES_Y, true); + if (ret) + return ret; + + return 0; +} + +static int magicmouse_setup_input_spi(struct input_dev *input, + struct hid_device *hdev) +{ + int ret = magicmouse_setup_input_int_tpd(input, hdev, J314_TP_MIN_X, + J314_TP_MIN_Y, J314_TP_MAX_X, + J314_TP_MAX_Y, J314_TP_RES_X, + J314_TP_RES_Y, true); + if (ret) + return ret; + + return 0; +} + +static int magicmouse_setup_input_t2(struct input_dev *input, + struct hid_device *hdev) +{ + int min_x, min_y, max_x, max_y, res_x, res_y; + + for (size_t i = 0; i < ARRAY_SIZE(magicmouse_t2_configs); i++) { + if (magicmouse_t2_configs[i].id == hdev->product) { + min_x = magicmouse_t2_configs[i].min_x; + min_y = magicmouse_t2_configs[i].min_y; + max_x = magicmouse_t2_configs[i].max_x; + max_y = magicmouse_t2_configs[i].max_y; + res_x = magicmouse_t2_configs[i].res_x; + res_y = magicmouse_t2_configs[i].res_y; + } + } + + int ret = magicmouse_setup_input_int_tpd(input, hdev, min_x, min_y, + max_x, max_y, res_x, res_y, false); + if (ret) + return ret; + return 0; } @@ -730,55 +1402,6 @@ static int magicmouse_input_configured(struct hid_device *hdev, return 0; } -static int magicmouse_enable_multitouch(struct hid_device *hdev) -{ - const u8 *feature; - const u8 feature_mt[] = { 0xD7, 0x01 }; - const u8 feature_mt_mouse2[] = { 0xF1, 0x02, 0x01 }; - const u8 feature_mt_trackpad2_usb[] = { 0x02, 0x01 }; - const u8 feature_mt_trackpad2_bt[] = { 0xF1, 0x02, 0x01 }; - u8 *buf; - int ret; - int feature_size; - - if (hdev->product == USB_DEVICE_ID_APPLE_MAGICTRACKPAD2 || - hdev->product == USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC) { - if (hdev->vendor == BT_VENDOR_ID_APPLE) { - feature_size = sizeof(feature_mt_trackpad2_bt); - feature = feature_mt_trackpad2_bt; - } else { /* USB_VENDOR_ID_APPLE */ - feature_size = sizeof(feature_mt_trackpad2_usb); - feature = feature_mt_trackpad2_usb; - } - } else if (hdev->product == USB_DEVICE_ID_APPLE_MAGICMOUSE2) { - feature_size = sizeof(feature_mt_mouse2); - feature = feature_mt_mouse2; - } else { - feature_size = sizeof(feature_mt); - feature = feature_mt; - } - - buf = kmemdup(feature, feature_size, GFP_KERNEL); - if (!buf) - return -ENOMEM; - - ret = hid_hw_raw_request(hdev, buf[0], buf, feature_size, - HID_FEATURE_REPORT, HID_REQ_SET_REPORT); - kfree(buf); - return ret; -} - -static void magicmouse_enable_mt_work(struct work_struct *work) -{ - struct magicmouse_sc *msc = - container_of(work, struct magicmouse_sc, work.work); - int ret; - - ret = magicmouse_enable_multitouch(msc->hdev); - if (ret < 0) - hid_err(msc->hdev, "unable to request touch data (%d)\n", ret); -} - static int magicmouse_fetch_battery(struct hid_device *hdev) { #ifdef CONFIG_HID_BATTERY_STRENGTH @@ -825,12 +1448,62 @@ static int magicmouse_probe(struct hid_device *hdev, struct hid_report *report; int ret; + if ((id->bus == BUS_SPI || id->bus == BUS_HOST) && id->vendor == SPI_VENDOR_ID_APPLE && + hdev->type != HID_TYPE_SPI_MOUSE) + return -ENODEV; + + switch (id->product) { + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F: + if (hdev->type != HID_TYPE_USBMOUSE) + return -ENODEV; + break; + } + msc = devm_kzalloc(&hdev->dev, sizeof(*msc), GFP_KERNEL); if (msc == NULL) { hid_err(hdev, "can't alloc magicmouse descriptor\n"); return -ENOMEM; } + // internal trackpad use a data format use input ops to avoid + // conflicts with the report ID. + switch (id->bus) { + case BUS_HOST: + msc->input_ops.raw_event = magicmouse_raw_event_mtp; + msc->input_ops.setup_input = magicmouse_setup_input_mtp; + break; + case BUS_SPI: + msc->input_ops.raw_event = magicmouse_raw_event_spi; + msc->input_ops.setup_input = magicmouse_setup_input_spi; + break; + default: + switch (id->product) { + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F: + msc->input_ops.raw_event = magicmouse_raw_event_t2; + msc->input_ops.setup_input = magicmouse_setup_input_t2; + break; + default: + msc->input_ops.raw_event = magicmouse_raw_event_usb; + msc->input_ops.setup_input = magicmouse_setup_input_usb; + } + } + msc->scroll_accel = SCROLL_ACCEL_DEFAULT; msc->hdev = hdev; INIT_DEFERRABLE_WORK(&msc->work, magicmouse_enable_mt_work); @@ -868,25 +1541,51 @@ static int magicmouse_probe(struct hid_device *hdev, goto err_stop_hw; } - if (id->product == USB_DEVICE_ID_APPLE_MAGICMOUSE) - report = hid_register_report(hdev, HID_INPUT_REPORT, - MOUSE_REPORT_ID, 0); - else if (id->product == USB_DEVICE_ID_APPLE_MAGICMOUSE2) - report = hid_register_report(hdev, HID_INPUT_REPORT, - MOUSE2_REPORT_ID, 0); - else if (id->product == USB_DEVICE_ID_APPLE_MAGICTRACKPAD2 || - id->product == USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC) { - if (id->vendor == BT_VENDOR_ID_APPLE) - report = hid_register_report(hdev, HID_INPUT_REPORT, - TRACKPAD2_BT_REPORT_ID, 0); - else /* USB_VENDOR_ID_APPLE */ + switch (id->bus) { + case BUS_SPI: + report = hid_register_report(hdev, HID_INPUT_REPORT, SPI_REPORT_ID, 0); + break; + case BUS_HOST: + report = hid_register_report(hdev, HID_INPUT_REPORT, MTP_REPORT_ID, 0); + break; + default: + switch (id->product) { + case USB_DEVICE_ID_APPLE_MAGICMOUSE: + report = hid_register_report(hdev, HID_INPUT_REPORT, MOUSE_REPORT_ID, 0); + break; + case USB_DEVICE_ID_APPLE_MAGICMOUSE2: + report = hid_register_report(hdev, HID_INPUT_REPORT, MOUSE2_REPORT_ID, 0); + break; + case USB_DEVICE_ID_APPLE_MAGICTRACKPAD2: + case USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC: + switch (id->vendor) { + case BT_VENDOR_ID_APPLE: + report = hid_register_report(hdev, HID_INPUT_REPORT, + TRACKPAD2_BT_REPORT_ID, 0); + break; + default: + report = hid_register_report(hdev, HID_INPUT_REPORT, + TRACKPAD2_USB_REPORT_ID, 0); + } + break; + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K: + case USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F: report = hid_register_report(hdev, HID_INPUT_REPORT, TRACKPAD2_USB_REPORT_ID, 0); - } else { /* USB_DEVICE_ID_APPLE_MAGICTRACKPAD */ - report = hid_register_report(hdev, HID_INPUT_REPORT, - TRACKPAD_REPORT_ID, 0); - report = hid_register_report(hdev, HID_INPUT_REPORT, - DOUBLE_REPORT_ID, 0); + break; + default: /* USB_DEVICE_ID_APPLE_MAGICTRACKPAD */ + report = hid_register_report(hdev, HID_INPUT_REPORT, + TRACKPAD_REPORT_ID, 0); + report = hid_register_report(hdev, HID_INPUT_REPORT, + DOUBLE_REPORT_ID, 0); + } } if (!report) { @@ -896,21 +1595,14 @@ static int magicmouse_probe(struct hid_device *hdev, } report->size = 6; - /* - * Some devices repond with 'invalid report id' when feature - * report switching it into multitouch mode is sent to it. - * - * This results in -EIO from the _raw low-level transport callback, - * but there seems to be no other way of switching the mode. - * Thus the super-ugly hacky success check below. - */ - ret = magicmouse_enable_multitouch(hdev); - if (ret != -EIO && ret < 0) { - hid_err(hdev, "unable to request touch data (%d)\n", ret); - goto err_stop_hw; - } - if (ret == -EIO && id->product == USB_DEVICE_ID_APPLE_MAGICMOUSE2) { - schedule_delayed_work(&msc->work, msecs_to_jiffies(500)); + /* MTP devices do not need the MT enable, this is handled by the MTP driver */ + if (id->bus == BUS_HOST) + return 0; + + /* SPI devices need to watch for reset events to re-send the MT enable */ + if (id->bus == BUS_SPI) { + report = hid_register_report(hdev, HID_INPUT_REPORT, SPI_RESET_REPORT_ID, 0); + report->size = 2; } return 0; @@ -981,10 +1673,42 @@ static const struct hid_device_id magic_mice[] = { USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC), .driver_data = 0 }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGICTRACKPAD2_USBC), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K), .driver_data = 0 }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F), .driver_data = 0 }, + { HID_SPI_DEVICE(SPI_VENDOR_ID_APPLE, HID_ANY_ID), + .driver_data = 0 }, + { HID_DEVICE(BUS_HOST, HID_GROUP_ANY, HOST_VENDOR_ID_APPLE, + HID_ANY_ID), .driver_data = 0 }, { } }; MODULE_DEVICE_TABLE(hid, magic_mice); +#ifdef CONFIG_PM +static int magicmouse_reset_resume(struct hid_device *hdev) +{ + if (hdev->bus == BUS_SPI) + return magicmouse_enable_multitouch(hdev); + + return 0; +} +#endif + static struct hid_driver magicmouse_driver = { .name = "magicmouse", .id_table = magic_mice, @@ -995,6 +1719,10 @@ static struct hid_driver magicmouse_driver = { .event = magicmouse_event, .input_mapping = magicmouse_input_mapping, .input_configured = magicmouse_input_configured, +#ifdef CONFIG_PM + .reset_resume = magicmouse_reset_resume, +#endif + }; module_hid_driver(magicmouse_driver); diff --git a/drivers/hid/hid-multitouch.c b/drivers/hid/hid-multitouch.c index e50887a6d22c..c436340331b4 100644 --- a/drivers/hid/hid-multitouch.c +++ b/drivers/hid/hid-multitouch.c @@ -73,6 +73,7 @@ MODULE_LICENSE("GPL"); #define MT_QUIRK_FORCE_MULTI_INPUT BIT(20) #define MT_QUIRK_DISABLE_WAKEUP BIT(21) #define MT_QUIRK_ORIENTATION_INVERT BIT(22) +#define MT_QUIRK_TOUCH_IS_TIPSTATE BIT(23) #define MT_INPUTMODE_TOUCHSCREEN 0x02 #define MT_INPUTMODE_TOUCHPAD 0x03 @@ -153,6 +154,7 @@ struct mt_class { __s32 sn_height; /* Signal/noise ratio for height events */ __s32 sn_pressure; /* Signal/noise ratio for pressure events */ __u8 maxcontacts; + bool is_direct; /* true for touchscreens */ bool is_indirect; /* true for touchpads */ bool export_all_inputs; /* do not ignore mouse, keyboards, etc... */ }; @@ -220,6 +222,7 @@ static void mt_post_parse(struct mt_device *td, struct mt_application *app); #define MT_CLS_GOOGLE 0x0111 #define MT_CLS_RAZER_BLADE_STEALTH 0x0112 #define MT_CLS_SMART_TECH 0x0113 +#define MT_CLS_APPLE_TOUCHBAR 0x0114 #define MT_CLS_SIS 0x0457 #define MT_DEFAULT_MAXCONTACT 10 @@ -405,6 +408,13 @@ static const struct mt_class mt_classes[] = { MT_QUIRK_CONTACT_CNT_ACCURATE | MT_QUIRK_SEPARATE_APP_REPORT, }, + { .name = MT_CLS_APPLE_TOUCHBAR, + .quirks = MT_QUIRK_HOVERING | + MT_QUIRK_TOUCH_IS_TIPSTATE | + MT_QUIRK_SLOT_IS_CONTACTID_MINUS_ONE, + .is_direct = true, + .maxcontacts = 11, + }, { .name = MT_CLS_SIS, .quirks = MT_QUIRK_NOT_SEEN_MEANS_UP | MT_QUIRK_ALWAYS_VALID | @@ -503,9 +513,6 @@ static void mt_feature_mapping(struct hid_device *hdev, if (!td->maxcontacts && field->logical_maximum <= MT_MAX_MAXCONTACT) td->maxcontacts = field->logical_maximum; - if (td->mtclass.maxcontacts) - /* check if the maxcontacts is given by the class */ - td->maxcontacts = td->mtclass.maxcontacts; break; case HID_DG_BUTTONTYPE: @@ -579,13 +586,13 @@ static struct mt_application *mt_allocate_application(struct mt_device *td, mt_application->application = application; INIT_LIST_HEAD(&mt_application->mt_usages); - if (application == HID_DG_TOUCHSCREEN) + if (application == HID_DG_TOUCHSCREEN && !td->mtclass.is_indirect) mt_application->mt_flags |= INPUT_MT_DIRECT; /* * Model touchscreens providing buttons as touchpads. */ - if (application == HID_DG_TOUCHPAD) { + if (application == HID_DG_TOUCHPAD && !td->mtclass.is_direct) { mt_application->mt_flags |= INPUT_MT_POINTER; td->inputmode_value = MT_INPUTMODE_TOUCHPAD; } @@ -649,7 +656,9 @@ static struct mt_report_data *mt_allocate_report_data(struct mt_device *td, if (field->logical == HID_DG_FINGER || td->hdev->group != HID_GROUP_MULTITOUCH_WIN_8) { for (n = 0; n < field->report_count; n++) { - if (field->usage[n].hid == HID_DG_CONTACTID) { + unsigned int hid = field->usage[n].hid; + + if (hid == HID_DG_CONTACTID || hid == HID_DG_TRANSDUCER_INDEX) { rdata->is_mt_collection = true; break; } @@ -821,6 +830,15 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, MT_STORE_FIELD(confidence_state); return 1; + case HID_DG_TOUCH: + /* + * Legacy devices use TIPSWITCH and not TOUCH. + * Let's just ignore this field unless the quirk is set. + */ + if (!(cls->quirks & MT_QUIRK_TOUCH_IS_TIPSTATE)) + return -1; + + fallthrough; case HID_DG_TIPSWITCH: if (field->application != HID_GD_SYSTEM_MULTIAXIS) input_set_capability(hi->input, @@ -828,6 +846,7 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, MT_STORE_FIELD(tip_state); return 1; case HID_DG_CONTACTID: + case HID_DG_TRANSDUCER_INDEX: MT_STORE_FIELD(contactid); app->touches_by_report++; return 1; @@ -883,10 +902,6 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, case HID_DG_CONTACTMAX: /* contact max are global to the report */ return -1; - case HID_DG_TOUCH: - /* Legacy devices use TIPSWITCH and not TOUCH. - * Let's just ignore this field. */ - return -1; } /* let hid-input decide for the others */ return 0; @@ -1314,6 +1329,10 @@ static int mt_touch_input_configured(struct hid_device *hdev, struct input_dev *input = hi->input; int ret; + /* check if the maxcontacts is given by the class */ + if (cls->maxcontacts) + td->maxcontacts = cls->maxcontacts; + if (!td->maxcontacts) td->maxcontacts = MT_DEFAULT_MAXCONTACT; @@ -1321,6 +1340,9 @@ static int mt_touch_input_configured(struct hid_device *hdev, if (td->serial_maybe) mt_post_parse_default_settings(td, app); + if (cls->is_direct) + app->mt_flags |= INPUT_MT_DIRECT; + if (cls->is_indirect) app->mt_flags |= INPUT_MT_POINTER; @@ -1772,6 +1794,15 @@ static int mt_probe(struct hid_device *hdev, const struct hid_device_id *id) } } + ret = hid_parse(hdev); + if (ret != 0) + return ret; + + if (mtclass->name == MT_CLS_APPLE_TOUCHBAR && + !hid_find_field(hdev, HID_INPUT_REPORT, + HID_DG_TOUCHPAD, HID_DG_TRANSDUCER_INDEX)) + return -ENODEV; + td = devm_kzalloc(&hdev->dev, sizeof(struct mt_device), GFP_KERNEL); if (!td) { dev_err(&hdev->dev, "cannot allocate multitouch data\n"); @@ -1819,10 +1850,6 @@ static int mt_probe(struct hid_device *hdev, const struct hid_device_id *id) timer_setup(&td->release_timer, mt_expired_timeout, 0); - ret = hid_parse(hdev); - if (ret != 0) - return ret; - if (mtclass->quirks & MT_QUIRK_FIX_CONST_CONTACT_ID) mt_fix_const_fields(hdev, HID_DG_CONTACTID); @@ -2304,6 +2331,11 @@ static const struct hid_device_id mt_devices[] = { MT_USB_DEVICE(USB_VENDOR_ID_XIROKU, USB_DEVICE_ID_XIROKU_CSR2) }, + /* Apple Touch Bars */ + { .driver_data = MT_CLS_APPLE_TOUCHBAR, + HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, + /* Google MT devices */ { .driver_data = MT_CLS_GOOGLE, HID_DEVICE(HID_BUS_ANY, HID_GROUP_ANY, USB_VENDOR_ID_GOOGLE, diff --git a/drivers/hid/hid-quirks.c b/drivers/hid/hid-quirks.c index 5d7a418ccdbe..b7f60ef8917c 100644 --- a/drivers/hid/hid-quirks.c +++ b/drivers/hid/hid-quirks.c @@ -312,6 +312,7 @@ static const struct hid_device_id hid_have_special_driver[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680) }, + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680_ALT) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223) }, @@ -328,8 +329,6 @@ static const struct hid_device_id hid_have_special_driver[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_GEYSER1_TP_ONLY) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_2021) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_FINGERPRINT_2021) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, #endif #if IS_ENABLED(CONFIG_HID_APPLEIR) { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL) }, @@ -338,6 +337,12 @@ static const struct hid_device_id hid_have_special_driver[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL4) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL5) }, #endif +#if IS_ENABLED(CONFIG_HID_APPLETB_BL) + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, +#endif +#if IS_ENABLED(CONFIG_HID_APPLETB_KBD) + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, +#endif #if IS_ENABLED(CONFIG_HID_ASUS) { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_KEYBOARD) }, { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_TOUCHPAD) }, @@ -957,14 +962,6 @@ static const struct hid_device_id hid_mouse_ignore_list[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRING9_ANSI) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRING9_ISO) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRING9_JIS) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_FOUNTAIN_TP_ONLY) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_GEYSER1_TP_ONLY) }, { } diff --git a/drivers/hid/spi-hid/Kconfig b/drivers/hid/spi-hid/Kconfig new file mode 100644 index 000000000000..8e37f0fec28a --- /dev/null +++ b/drivers/hid/spi-hid/Kconfig @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: GPL-2.0-only +menu "SPI HID support" + depends on SPI + +config SPI_HID_APPLE_OF + tristate "HID over SPI transport layer for Apple Silicon SoCs" + default ARCH_APPLE + depends on SPI && INPUT && OF + help + Say Y here if you use Apple Silicon based laptop. The keyboard and + touchpad are HID based devices connected via SPI. + + If unsure, say N. + + This support is also available as a module. If so, the module + will be called spi-hid-apple-of. It will also build/depend on the + module spi-hid-apple. + +endmenu + +config SPI_HID_APPLE_CORE + tristate + default y if SPI_HID_APPLE_OF=y + default m if SPI_HID_APPLE_OF=m + select HID + select CRC16 diff --git a/drivers/hid/spi-hid/Makefile b/drivers/hid/spi-hid/Makefile new file mode 100644 index 000000000000..f276ee12cb94 --- /dev/null +++ b/drivers/hid/spi-hid/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for SPI HID tarnsport drivers +# + +obj-$(CONFIG_SPI_HID_APPLE_CORE) += spi-hid-apple.o + +spi-hid-apple-objs = spi-hid-apple-core.o + +obj-$(CONFIG_SPI_HID_APPLE_OF) += spi-hid-apple-of.o diff --git a/drivers/hid/spi-hid/spi-hid-apple-core.c b/drivers/hid/spi-hid/spi-hid-apple-core.c new file mode 100644 index 000000000000..1f8fa64d6d86 --- /dev/null +++ b/drivers/hid/spi-hid/spi-hid-apple-core.c @@ -0,0 +1,1194 @@ +/* + * SPDX-License-Identifier: GPL-2.0 + * + * Apple SPI HID transport driver + * + * Copyright (C) The Asahi Linux Contributors + * + * Based on: drivers/input/applespi.c + * + * MacBook (Pro) SPI keyboard and touchpad driver + * + * Copyright (c) 2015-2018 Federico Lorenzi + * Copyright (c) 2017-2018 Ronald Tschalär + * + */ + +//#define DEBUG 2 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "spi-hid-apple.h" + +#define SPIHID_DEF_WAIT msecs_to_jiffies(1000) + +#define SPIHID_MAX_INPUT_REPORT_SIZE 0x800 + +/* support only keyboard, trackpad and management dev for now */ +#define SPIHID_MAX_DEVICES 3 + +#define SPIHID_DEVICE_ID_MNGT 0x0 +#define SPIHID_DEVICE_ID_KBD 0x1 +#define SPIHID_DEVICE_ID_TP 0x2 +#define SPIHID_DEVICE_ID_INFO 0xd0 + +#define SPIHID_READ_PACKET 0x20 +#define SPIHID_WRITE_PACKET 0x40 + +#define SPIHID_DESC_MAX 512 + +#define SPIHID_SET_LEDS 0x0151 /* caps lock */ + +#define SPI_RW_CHG_DELAY_US 200 /* 'Inter Stage Us'? */ + +static const u8 spi_hid_apple_booted[4] = { 0xa0, 0x80, 0x00, 0x00 }; +static const u8 spi_hid_apple_status_ok[4] = { 0xac, 0x27, 0x68, 0xd5 }; + +struct spihid_interface { + struct hid_device *hid; + u8 *hid_desc; + u32 hid_desc_len; + u32 id; + unsigned country; + u32 max_control_report_len; + u32 max_input_report_len; + u32 max_output_report_len; + u8 name[32]; + u8 reply_buf[SPIHID_DESC_MAX]; + u32 reply_len; + bool ready; +}; + +struct spihid_input_report { + u8 *buf; + u32 length; + u32 offset; + u8 device; + u8 flags; +}; + +struct spihid_apple { + struct spi_device *spidev; + + struct spihid_apple_ops *ops; + + struct spihid_interface mngt; + struct spihid_interface kbd; + struct spihid_interface tp; + + wait_queue_head_t wait; + struct mutex tx_lock; //< protects against concurrent SPI writes + + struct spi_message rx_msg; + struct spi_message tx_msg; + struct spi_transfer rx_transfer; + struct spi_transfer tx_transfer; + struct spi_transfer status_transfer; + + u8 *rx_buf; + u8 *tx_buf; + u8 *status_buf; + + u8 vendor[32]; + u8 product[64]; + u8 serial[32]; + + u32 num_devices; + + u32 vendor_id; + u32 product_id; + u32 version_number; + + u8 msg_id; + + /* fragmented HID report */ + struct spihid_input_report report; + + /* state tracking flags */ + bool status_booted; + +#ifdef IRQ_WAKE_SUPPORT + bool irq_wake_enabled; +#endif +}; + +/** + * struct spihid_msg_hdr - common header of protocol messages. + * + * Each message begins with fixed header, followed by a message-type specific + * payload, and ends with a 16-bit crc. Because of the varying lengths of the + * payload, the crc is defined at the end of each payload struct, rather than + * in this struct. + * + * @unknown0: request type? output, input (0x10), feature, protocol + * @unknown1: maybe report id? + * @unknown2: mostly zero, in info request maybe device num + * @msgid: incremented on each message, rolls over after 255; there is a + * separate counter for each message type. + * @rsplen: response length (the exact nature of this field is quite + * speculative). On a request/write this is often the same as + * @length, though in some cases it has been seen to be much larger + * (e.g. 0x400); on a response/read this the same as on the + * request; for reads that are not responses it is 0. + * @length: length of the remainder of the data in the whole message + * structure (after re-assembly in case of being split over + * multiple spi-packets), minus the trailing crc. The total size + * of a message is therefore @length + 10. + */ + +struct spihid_msg_hdr { + u8 unknown0; + u8 unknown1; + u8 unknown2; + u8 id; + __le16 rsplen; + __le16 length; +}; + +/** + * struct spihid_transfer_packet - a complete spi packet; always 256 bytes. This carries + * the (parts of the) message in the data. But note that this does not + * necessarily contain a complete message, as in some cases (e.g. many + * fingers pressed) the message is split over multiple packets (see the + * @offset, @remain, and @length fields). In general the data parts in + * spihid_transfer_packet's are concatenated until @remaining is 0, and the + * result is an message. + * + * @flags: 0x40 = write (to device), 0x20 = read (from device); note that + * the response to a write still has 0x40. + * @device: 1 = keyboard, 2 = touchpad + * @offset: specifies the offset of this packet's data in the complete + * message; i.e. > 0 indicates this is a continuation packet (in + * the second packet for a message split over multiple packets + * this would then be the same as the @length in the first packet) + * @remain: number of message bytes remaining in subsequents packets (in + * the first packet of a message split over two packets this would + * then be the same as the @length in the second packet) + * @length: length of the valid data in the @data in this packet + * @data: all or part of a message + * @crc16: crc over this whole structure minus this @crc16 field. This + * covers just this packet, even on multi-packet messages (in + * contrast to the crc in the message). + */ +struct spihid_transfer_packet { + u8 flags; + u8 device; + __le16 offset; + __le16 remain; + __le16 length; + u8 data[246]; + __le16 crc16; +}; + +/* + * how HID is mapped onto the protocol is not fully clear. This are the known + * reports/request: + * + * pkt.flags pkt.dev? msg.u0 msg.u1 msg.u2 + * info 0x40 0xd0 0x20 0x01 0xd0 + * + * info mngt: 0x40 0xd0 0x20 0x10 0x00 + * info kbd: 0x40 0xd0 0x20 0x10 0x01 + * info tp: 0x40 0xd0 0x20 0x10 0x02 + * + * desc kbd: 0x40 0xd0 0x20 0x10 0x01 + * desc trackpad: 0x40 0xd0 0x20 0x10 0x02 + * + * mt mode: 0x40 0x02 0x52 0x02 0x00 set protocol? + * capslock led 0x40 0x01 0x51 0x01 0x00 output report + * + * report kbd: 0x20 0x01 0x10 0x01 0x00 input report + * report tp: 0x20 0x02 0x10 0x02 0x00 input report + * + */ + + +static int spihid_apple_request(struct spihid_apple *spihid, u8 target, u8 unk0, + u8 unk1, u8 unk2, u16 resp_len, u8 *buf, + size_t len) +{ + struct spihid_transfer_packet *pkt; + struct spihid_msg_hdr *hdr; + u16 crc; + int err; + + /* know reports are small enoug to fit in a single packet */ + if (len > sizeof(pkt->data) - sizeof(*hdr) - sizeof(__le16)) + return -EINVAL; + + err = mutex_lock_interruptible(&spihid->tx_lock); + if (err < 0) + return err; + + pkt = (struct spihid_transfer_packet *)spihid->tx_buf; + + memset(pkt, 0, sizeof(*pkt)); + pkt->flags = SPIHID_WRITE_PACKET; + pkt->device = target; + pkt->length = cpu_to_le16(sizeof(*hdr) + len + sizeof(__le16)); + + hdr = (struct spihid_msg_hdr *)&pkt->data[0]; + hdr->unknown0 = unk0; + hdr->unknown1 = unk1; + hdr->unknown2 = unk2; + hdr->id = spihid->msg_id++; + hdr->rsplen = cpu_to_le16(resp_len); + hdr->length = cpu_to_le16(len); + + if (len) + memcpy(pkt->data + sizeof(*hdr), buf, len); + crc = crc16(0, &pkt->data[0], sizeof(*hdr) + len); + put_unaligned_le16(crc, pkt->data + sizeof(*hdr) + len); + + pkt->crc16 = cpu_to_le16(crc16(0, spihid->tx_buf, + offsetof(struct spihid_transfer_packet, crc16))); + + memset(spihid->status_buf, 0, sizeof(spi_hid_apple_status_ok)); + + err = spi_sync(spihid->spidev, &spihid->tx_msg); + + if (memcmp(spihid->status_buf, spi_hid_apple_status_ok, + sizeof(spi_hid_apple_status_ok))) { + u8 *b = spihid->status_buf; + dev_warn_ratelimited(&spihid->spidev->dev, "status message " + "mismatch: %02x %02x %02x %02x\n", + b[0], b[1], b[2], b[3]); + } + mutex_unlock(&spihid->tx_lock); + if (err < 0) + return err; + + return (int)len; +} + +static struct spihid_apple *spihid_get_data(struct spihid_interface *idev) +{ + switch (idev->id) { + case SPIHID_DEVICE_ID_KBD: + return container_of(idev, struct spihid_apple, kbd); + case SPIHID_DEVICE_ID_TP: + return container_of(idev, struct spihid_apple, tp); + default: + return NULL; + } +} + +static int apple_ll_start(struct hid_device *hdev) +{ + /* no-op SPI transport is already setup */ + return 0; +}; + +static void apple_ll_stop(struct hid_device *hdev) +{ + /* no-op, devices will be desstroyed on driver destruction */ +} + +static int apple_ll_open(struct hid_device *hdev) +{ + struct spihid_apple *spihid; + struct spihid_interface *idev = hdev->driver_data; + + if (idev->hid_desc_len == 0) { + spihid = spihid_get_data(idev); + dev_warn(&spihid->spidev->dev, + "HID descriptor missing for dev %u", idev->id); + } else + idev->ready = true; + + return 0; +} + +static void apple_ll_close(struct hid_device *hdev) +{ + struct spihid_interface *idev = hdev->driver_data; + idev->ready = false; +} + +static int apple_ll_parse(struct hid_device *hdev) +{ + struct spihid_interface *idev = hdev->driver_data; + + return hid_parse_report(hdev, idev->hid_desc, idev->hid_desc_len); +} + +static int apple_ll_raw_request(struct hid_device *hdev, + unsigned char reportnum, __u8 *buf, size_t len, + unsigned char rtype, int reqtype) +{ + struct spihid_interface *idev = hdev->driver_data; + struct spihid_apple *spihid = spihid_get_data(idev); + int ret; + + dev_dbg(&spihid->spidev->dev, + "apple_ll_raw_request: device:%u reportnum:%hhu rtype:%hhu", + idev->id, reportnum, rtype); + + switch (reqtype) { + case HID_REQ_GET_REPORT: + if (rtype != HID_FEATURE_REPORT) + return -EINVAL; + + idev->reply_len = 0; + ret = spihid_apple_request(spihid, idev->id, 0x32, reportnum, 0x00, len, NULL, 0); + if (ret < 0) + return ret; + + ret = wait_event_interruptible_timeout(spihid->wait, idev->reply_len, + SPIHID_DEF_WAIT); + if (ret == 0) + ret = -ETIMEDOUT; + if (ret < 0) { + dev_err(&spihid->spidev->dev, "waiting for get report failed: %d", ret); + return ret; + } + memcpy(buf, idev->reply_buf, max_t(size_t, len, idev->reply_len)); + return idev->reply_len; + + case HID_REQ_SET_REPORT: + if (buf[0] != reportnum) + return -EINVAL; + if (reportnum != idev->id) { + dev_warn(&spihid->spidev->dev, + "device:%u reportnum:" + "%hhu mismatch", + idev->id, reportnum); + return -EINVAL; + } + return spihid_apple_request(spihid, idev->id, 0x52, reportnum, 0x00, 2, buf, len); + default: + return -EIO; + } +} + +static int apple_ll_output_report(struct hid_device *hdev, __u8 *buf, + size_t len) +{ + struct spihid_interface *idev = hdev->driver_data; + struct spihid_apple *spihid = spihid_get_data(idev); + if (!spihid) + return -1; + + dev_dbg(&spihid->spidev->dev, + "apple_ll_output_report: device:%u len:%zu:", + idev->id, len); + // second idev->id should maybe be buf[0]? + return spihid_apple_request(spihid, idev->id, 0x51, idev->id, 0x00, 0, buf, len); +} + +static struct hid_ll_driver apple_hid_ll = { + .start = &apple_ll_start, + .stop = &apple_ll_stop, + .open = &apple_ll_open, + .close = &apple_ll_close, + .parse = &apple_ll_parse, + .raw_request = &apple_ll_raw_request, + .output_report = &apple_ll_output_report, + .max_buffer_size = SPIHID_MAX_INPUT_REPORT_SIZE, +}; + +static struct spihid_interface *spihid_get_iface(struct spihid_apple *spihid, + u32 iface) +{ + switch (iface) { + case SPIHID_DEVICE_ID_MNGT: + return &spihid->mngt; + case SPIHID_DEVICE_ID_KBD: + return &spihid->kbd; + case SPIHID_DEVICE_ID_TP: + return &spihid->tp; + default: + return NULL; + } +} + +static int spihid_verify_msg(struct spihid_apple *spihid, u8 *buf, size_t len) +{ + u16 msg_crc, crc; + struct device *dev = &spihid->spidev->dev; + + crc = crc16(0, buf, len - sizeof(__le16)); + msg_crc = get_unaligned_le16(buf + len - sizeof(__le16)); + if (crc != msg_crc) { + dev_warn_ratelimited(dev, "Read message crc mismatch\n"); + return 0; + } + return 1; +} + +static bool spihid_status_report(struct spihid_apple *spihid, u8 *pl, + size_t len) +{ + struct device *dev = &spihid->spidev->dev; + dev_dbg(dev, "%s: len: %zu", __func__, len); + if (len == 5 && pl[0] == 0xe0) + return true; + + return false; +} + +static bool spihid_process_input_report(struct spihid_apple *spihid, u32 device, + struct spihid_msg_hdr *hdr, u8 *payload, + size_t len) +{ + //dev_dbg(&spihid>spidev->dev, "input report: req:%hx iface:%u ", hdr->unknown0, device); + if (hdr->unknown0 != 0x10) + return false; + + /* HID device as well but Vendor usage only, handle it internally for now */ + if (device == 0) { + if (hdr->unknown1 == 0xe0) { + return spihid_status_report(spihid, payload, len); + } + } else if (device < SPIHID_MAX_DEVICES) { + struct spihid_interface *iface = + spihid_get_iface(spihid, device); + if (iface && iface->hid && iface->ready) { + hid_input_report(iface->hid, HID_INPUT_REPORT, payload, + len, 1); + return true; + } + } else + dev_dbg(&spihid->spidev->dev, + "unexpected iface:%u for input report", device); + + return false; +} + +struct spihid_device_info { + __le16 u0[2]; + __le16 num_devices; + __le16 vendor_id; + __le16 product_id; + __le16 version_number; + __le16 vendor_str[2]; //< offset and string length + __le16 product_str[2]; //< offset and string length + __le16 serial_str[2]; //< offset and string length +}; + +static bool spihid_process_device_info(struct spihid_apple *spihid, u32 iface, + u8 *payload, size_t len) +{ + struct device *dev = &spihid->spidev->dev; + + if (iface != SPIHID_DEVICE_ID_INFO) + return false; + + if (spihid->vendor_id == 0 && + len >= sizeof(struct spihid_device_info)) { + struct spihid_device_info *info = + (struct spihid_device_info *)payload; + u16 voff, vlen, poff, plen, soff, slen; + u32 num_devices; + + num_devices = __le16_to_cpu(info->num_devices); + + if (num_devices < SPIHID_MAX_DEVICES) { + dev_err(dev, + "Device info reports %u devices, expecting at least 3", + num_devices); + return false; + } + spihid->num_devices = num_devices; + + if (spihid->num_devices > SPIHID_MAX_DEVICES) { + dev_info( + dev, + "limiting the number of devices to mngt, kbd and mouse"); + spihid->num_devices = SPIHID_MAX_DEVICES; + } + + spihid->vendor_id = __le16_to_cpu(info->vendor_id); + spihid->product_id = __le16_to_cpu(info->product_id); + spihid->version_number = __le16_to_cpu(info->version_number); + + voff = __le16_to_cpu(info->vendor_str[0]); + vlen = __le16_to_cpu(info->vendor_str[1]); + + if (voff < len && vlen <= len - voff && + vlen < sizeof(spihid->vendor)) { + memcpy(spihid->vendor, payload + voff, vlen); + spihid->vendor[vlen] = '\0'; + } + + poff = __le16_to_cpu(info->product_str[0]); + plen = __le16_to_cpu(info->product_str[1]); + + if (poff < len && plen <= len - poff && + plen < sizeof(spihid->product)) { + memcpy(spihid->product, payload + poff, plen); + spihid->product[plen] = '\0'; + } + + soff = __le16_to_cpu(info->serial_str[0]); + slen = __le16_to_cpu(info->serial_str[1]); + + if (soff < len && slen <= len - soff && + slen < sizeof(spihid->serial)) { + memcpy(spihid->vendor, payload + soff, slen); + spihid->serial[slen] = '\0'; + } + + wake_up_interruptible(&spihid->wait); + } + return true; +} + +struct spihid_iface_info { + u8 u_0; + u8 interface_num; + u8 u_2; + u8 u_3; + u8 u_4; + u8 country_code; + __le16 max_input_report_len; + __le16 max_output_report_len; + __le16 max_control_report_len; + __le16 name_offset; + __le16 name_length; +}; + +static bool spihid_process_iface_info(struct spihid_apple *spihid, u32 num, + u8 *payload, size_t len) +{ + struct spihid_iface_info *info; + struct spihid_interface *iface = spihid_get_iface(spihid, num); + u32 name_off, name_len; + + if (!iface) + return false; + + if (!iface->max_input_report_len) { + if (len < sizeof(*info)) + return false; + + info = (struct spihid_iface_info *)payload; + + iface->max_input_report_len = + le16_to_cpu(info->max_input_report_len); + iface->max_output_report_len = + le16_to_cpu(info->max_output_report_len); + iface->max_control_report_len = + le16_to_cpu(info->max_control_report_len); + iface->country = info->country_code; + + name_off = le16_to_cpu(info->name_offset); + name_len = le16_to_cpu(info->name_length); + + if (name_off < len && name_len <= len - name_off && + name_len < sizeof(iface->name)) { + memcpy(iface->name, payload + name_off, name_len); + iface->name[name_len] = '\0'; + } + + dev_dbg(&spihid->spidev->dev, "Info for %s, country code: 0x%x", + iface->name, iface->country); + + wake_up_interruptible(&spihid->wait); + } + + return true; +} + +static int spihid_register_hid_device(struct spihid_apple *spihid, + struct spihid_interface *idev, u8 device); + +static bool spihid_process_iface_hid_report_desc(struct spihid_apple *spihid, + u32 num, u8 *payload, + size_t len) +{ + struct spihid_interface *iface = spihid_get_iface(spihid, num); + + if (!iface) + return false; + + if (iface->hid_desc_len == 0) { + if (len > SPIHID_DESC_MAX) + return false; + memcpy(iface->hid_desc, payload, len); + iface->hid_desc_len = len; + + /* do not register the mngt iface as HID device */ + if (num > 0) + spihid_register_hid_device(spihid, iface, num); + + wake_up_interruptible(&spihid->wait); + } + return true; +} + +static bool spihid_process_iface_get_report(struct spihid_apple *spihid, + u32 device, u8 report, + u8 *payload, size_t len) +{ + struct spihid_interface *iface = spihid_get_iface(spihid, device); + + if (!iface) + return false; + + if (len > sizeof(iface->reply_buf) || len < 1) + return false; + + memcpy(iface->reply_buf, payload, len); + iface->reply_len = len; + + wake_up_interruptible(&spihid->wait); + + return true; +} + +static bool spihid_process_response(struct spihid_apple *spihid, u32 device, + struct spihid_msg_hdr *hdr, u8 *payload, + size_t len) +{ + if (hdr->unknown0 == 0x20) { + switch (hdr->unknown1) { + case 0x01: + return spihid_process_device_info(spihid, hdr->unknown2, + payload, len); + case 0x02: + return spihid_process_iface_info(spihid, hdr->unknown2, + payload, len); + case 0x10: + return spihid_process_iface_hid_report_desc( + spihid, hdr->unknown2, payload, len); + default: + break; + } + } + + if (hdr->unknown0 == 0x32) { + return spihid_process_iface_get_report(spihid, device, hdr->unknown1, payload, len); + } + + return false; +} + +static void spihid_process_message(struct spihid_apple *spihid, u8 *data, + size_t length, u8 device, u8 flags) +{ + struct device *dev = &spihid->spidev->dev; + struct spihid_msg_hdr *hdr; + bool handled = false; + size_t payload_len; + u8 *payload; + + if (!spihid_verify_msg(spihid, data, length)) + return; + + hdr = (struct spihid_msg_hdr *)data; + payload_len = le16_to_cpu(hdr->length); + + if (payload_len == 0 || + (payload_len + sizeof(struct spihid_msg_hdr) + 2) > length) + return; + + payload = data + sizeof(struct spihid_msg_hdr); + + switch (flags) { + case SPIHID_READ_PACKET: + handled = spihid_process_input_report(spihid, device, hdr, + payload, payload_len); + break; + case SPIHID_WRITE_PACKET: + handled = spihid_process_response(spihid, device, hdr, payload, + payload_len); + break; + default: + break; + } + +#if defined(DEBUG) && DEBUG > 1 + { + dev_dbg(dev, + "R msg: req:%02hhx rep:%02hhx dev:%02hhx id:%hu len:%hu\n", + hdr->unknown0, hdr->unknown1, hdr->unknown2, hdr->id, + hdr->length); + print_hex_dump_debug("spihid msg: ", DUMP_PREFIX_OFFSET, 16, 1, + payload, le16_to_cpu(hdr->length), true); + } +#else + if (!handled) { + dev_dbg(dev, + "R unhandled msg: req:%02hhx rep:%02hhx dev:%02hhx id:%hu len:%hu\n", + hdr->unknown0, hdr->unknown1, hdr->unknown2, hdr->id, + hdr->length); + print_hex_dump_debug("spihid msg: ", DUMP_PREFIX_OFFSET, 16, 1, + payload, le16_to_cpu(hdr->length), true); + } +#endif +} + +static void spihid_assemble_message(struct spihid_apple *spihid, + struct spihid_transfer_packet *pkt) +{ + size_t length, offset, remain; + struct device *dev = &spihid->spidev->dev; + struct spihid_input_report *rep = &spihid->report; + + length = le16_to_cpu(pkt->length); + remain = le16_to_cpu(pkt->remain); + offset = le16_to_cpu(pkt->offset); + + if (offset + length + remain > U16_MAX) { + return; + } + + if (pkt->device != rep->device || pkt->flags != rep->flags || + offset != rep->offset) { + rep->device = 0; + rep->flags = 0; + rep->offset = 0; + rep->length = 0; + } + + if (offset == 0) { + if (rep->offset != 0) { + dev_warn(dev, "incomplete report off:%u len:%u", + rep->offset, rep->length); + } + memcpy(rep->buf, pkt->data, length); + rep->offset = length; + rep->length = length + remain; + rep->device = pkt->device; + rep->flags = pkt->flags; + } else if (offset == rep->offset) { + if (offset + length + remain != rep->length) { + dev_warn(dev, "incomplete report off:%u len:%u", + rep->offset, rep->length); + return; + } + memcpy(rep->buf + offset, pkt->data, length); + rep->offset += length; + + if (rep->offset == rep->length) { + spihid_process_message(spihid, rep->buf, rep->length, + rep->device, rep->flags); + rep->device = 0; + rep->flags = 0; + rep->offset = 0; + rep->length = 0; + } + } +} + +static void spihid_process_read(struct spihid_apple *spihid) +{ + u16 crc; + size_t length; + struct device *dev = &spihid->spidev->dev; + struct spihid_transfer_packet *pkt; + + pkt = (struct spihid_transfer_packet *)spihid->rx_buf; + + /* check transfer packet crc */ + crc = crc16(0, spihid->rx_buf, + offsetof(struct spihid_transfer_packet, crc16)); + if (crc != le16_to_cpu(pkt->crc16)) { + dev_warn_ratelimited(dev, "Read package crc mismatch\n"); + return; + } + + length = le16_to_cpu(pkt->length); + + if (length < sizeof(struct spihid_msg_hdr) + 2) { + if (length == sizeof(spi_hid_apple_booted) && + !memcmp(pkt->data, spi_hid_apple_booted, length)) { + if (!spihid->status_booted) { + spihid->status_booted = true; + wake_up_interruptible(&spihid->wait); + } + } else { + dev_info(dev, "R short packet: len:%zu\n", length); + print_hex_dump(KERN_INFO, "spihid pkt:", + DUMP_PREFIX_OFFSET, 16, 1, pkt->data, + length, false); + } + return; + } + +#if defined(DEBUG) && DEBUG > 1 + dev_dbg(dev, + "R pkt: flags:%02hhx dev:%02hhx off:%hu remain:%hu, len:%zu\n", + pkt->flags, pkt->device, pkt->offset, pkt->remain, length); +#if defined(DEBUG) && DEBUG > 2 + print_hex_dump_debug("spihid pkt: ", DUMP_PREFIX_OFFSET, 16, 1, + spihid->rx_buf, + sizeof(struct spihid_transfer_packet), true); +#endif +#endif + + if (length > sizeof(pkt->data)) { + dev_warn_ratelimited(dev, "Invalid pkt len:%zu", length); + return; + } + + /* short message */ + if (pkt->offset == 0 && pkt->remain == 0) { + spihid_process_message(spihid, pkt->data, length, pkt->device, + pkt->flags); + } else { + spihid_assemble_message(spihid, pkt); + } +} + +static void spihid_read_packet_sync(struct spihid_apple *spihid) +{ + int err; + + err = spi_sync(spihid->spidev, &spihid->rx_msg); + if (!err) { + spihid_process_read(spihid); + } else { + dev_warn(&spihid->spidev->dev, "RX failed: %d\n", err); + } +} + +irqreturn_t spihid_apple_core_irq(int irq, void *data) +{ + struct spi_device *spi = data; + struct spihid_apple *spihid = spi_get_drvdata(spi); + + spihid_read_packet_sync(spihid); + + return IRQ_HANDLED; +} +EXPORT_SYMBOL_GPL(spihid_apple_core_irq); + +static void spihid_apple_setup_spi_msgs(struct spihid_apple *spihid) +{ + memset(&spihid->rx_transfer, 0, sizeof(spihid->rx_transfer)); + + spihid->rx_transfer.rx_buf = spihid->rx_buf; + spihid->rx_transfer.len = sizeof(struct spihid_transfer_packet); + + spi_message_init(&spihid->rx_msg); + spi_message_add_tail(&spihid->rx_transfer, &spihid->rx_msg); + + memset(&spihid->tx_transfer, 0, sizeof(spihid->rx_transfer)); + memset(&spihid->status_transfer, 0, sizeof(spihid->status_transfer)); + + spihid->tx_transfer.tx_buf = spihid->tx_buf; + spihid->tx_transfer.len = sizeof(struct spihid_transfer_packet); + spihid->tx_transfer.delay.unit = SPI_DELAY_UNIT_USECS; + spihid->tx_transfer.delay.value = SPI_RW_CHG_DELAY_US; + + spihid->status_transfer.rx_buf = spihid->status_buf; + spihid->status_transfer.len = sizeof(spi_hid_apple_status_ok); + + spi_message_init(&spihid->tx_msg); + spi_message_add_tail(&spihid->tx_transfer, &spihid->tx_msg); + spi_message_add_tail(&spihid->status_transfer, &spihid->tx_msg); +} + +static int spihid_apple_setup_spi(struct spihid_apple *spihid) +{ + spihid_apple_setup_spi_msgs(spihid); + + return spihid->ops->power_on(spihid->ops); +} + +static int spihid_register_hid_device(struct spihid_apple *spihid, + struct spihid_interface *iface, u8 device) +{ + int ret; + char *suffix; + struct hid_device *hid; + + iface->id = device; + + hid = hid_allocate_device(); + if (IS_ERR(hid)) + return PTR_ERR(hid); + + /* + * Use 'Apple SPI Keyboard' and 'Apple SPI Trackpad' as input device + * names. The device names need to be distinct since at least Kwin uses + * the tripple Vendor ID, Product ID, Name to identify devices. + */ + snprintf(hid->name, sizeof(hid->name), "Apple SPI %s", iface->name); + // strip ' / Boot' suffix from the name + suffix = strstr(hid->name, " / Boot"); + if (suffix) + suffix[0] = '\0'; + snprintf(hid->phys, sizeof(hid->phys), "%s (%hhx)", + dev_name(&spihid->spidev->dev), device); + strscpy(hid->uniq, spihid->serial, sizeof(hid->uniq)); + + hid->ll_driver = &apple_hid_ll; + hid->bus = BUS_SPI; + hid->vendor = spihid->vendor_id; + hid->product = spihid->product_id; + hid->version = spihid->version_number; + + if (device == SPIHID_DEVICE_ID_KBD) + hid->type = HID_TYPE_SPI_KEYBOARD; + else if (device == SPIHID_DEVICE_ID_TP) + hid->type = HID_TYPE_SPI_MOUSE; + + hid->country = iface->country; + hid->dev.parent = &spihid->spidev->dev; + hid->driver_data = iface; + + ret = hid_add_device(hid); + if (ret < 0) { + hid_destroy_device(hid); + dev_warn(&spihid->spidev->dev, + "Failed to register hid device %hhu", device); + return ret; + } + + iface->hid = hid; + + return 0; +} + +static void spihid_destroy_hid_device(struct spihid_interface *iface) +{ + if (iface->hid) { + hid_destroy_device(iface->hid); + iface->hid = NULL; + } + iface->ready = false; +} + +int spihid_apple_core_probe(struct spi_device *spi, struct spihid_apple_ops *ops) +{ + struct device *dev = &spi->dev; + struct spihid_apple *spihid; + int err, i; + + if (!ops || !ops->power_on || !ops->power_off || !ops->enable_irq || !ops->disable_irq) + return -EINVAL; + + spihid = devm_kzalloc(dev, sizeof(*spihid), GFP_KERNEL); + if (!spihid) + return -ENOMEM; + + spihid->ops = ops; + spihid->spidev = spi; + + // init spi + spi_set_drvdata(spi, spihid); + + /* + * allocate SPI buffers + * Overallocate the receice buffer since it passed directly into + * hid_input_report / hid_report_raw_event. The later expects the buffer + * to be HID_MAX_BUFFER_SIZE (16k) or hid_ll_driver.max_buffer_size if + * set. + */ + spihid->rx_buf = devm_kmalloc( + &spi->dev, SPIHID_MAX_INPUT_REPORT_SIZE, GFP_KERNEL); + spihid->tx_buf = devm_kmalloc( + &spi->dev, sizeof(struct spihid_transfer_packet), GFP_KERNEL); + spihid->status_buf = devm_kmalloc( + &spi->dev, sizeof(spi_hid_apple_status_ok), GFP_KERNEL); + + if (!spihid->rx_buf || !spihid->tx_buf || !spihid->status_buf) + return -ENOMEM; + + spihid->report.buf = + devm_kmalloc(dev, SPIHID_MAX_INPUT_REPORT_SIZE, GFP_KERNEL); + + spihid->kbd.hid_desc = devm_kmalloc(dev, SPIHID_DESC_MAX, GFP_KERNEL); + spihid->tp.hid_desc = devm_kmalloc(dev, SPIHID_DESC_MAX, GFP_KERNEL); + + if (!spihid->report.buf || !spihid->kbd.hid_desc || + !spihid->tp.hid_desc) + return -ENOMEM; + + init_waitqueue_head(&spihid->wait); + + mutex_init(&spihid->tx_lock); + + /* Init spi transfer buffers and power device on */ + err = spihid_apple_setup_spi(spihid); + if (err < 0) + goto error; + + /* enable HID irq */ + spihid->ops->enable_irq(spihid->ops); + + // wait for boot message + err = wait_event_interruptible_timeout(spihid->wait, + spihid->status_booted, + msecs_to_jiffies(1000)); + if (err == 0) + err = -ENODEV; + if (err < 0) { + dev_err(dev, "waiting for device boot failed: %d", err); + goto error; + } + + /* request device information */ + dev_dbg(dev, "request device info"); + spihid_apple_request(spihid, 0xd0, 0x20, 0x01, 0xd0, 0, NULL, 0); + err = wait_event_interruptible_timeout(spihid->wait, spihid->vendor_id, + SPIHID_DEF_WAIT); + if (err == 0) + err = -ENODEV; + if (err < 0) { + dev_err(dev, "waiting for device info failed: %d", err); + goto error; + } + + /* request interface information */ + for (i = 0; i < spihid->num_devices; i++) { + struct spihid_interface *iface = spihid_get_iface(spihid, i); + if (!iface) + continue; + dev_dbg(dev, "request interface info 0x%02x", i); + spihid_apple_request(spihid, 0xd0, 0x20, 0x02, i, + SPIHID_DESC_MAX, NULL, 0); + err = wait_event_interruptible_timeout( + spihid->wait, iface->max_input_report_len, + SPIHID_DEF_WAIT); + } + + /* request HID report descriptors */ + for (i = 1; i < spihid->num_devices; i++) { + struct spihid_interface *iface = spihid_get_iface(spihid, i); + if (!iface) + continue; + dev_dbg(dev, "request hid report desc 0x%02x", i); + spihid_apple_request(spihid, 0xd0, 0x20, 0x10, i, + SPIHID_DESC_MAX, NULL, 0); + wait_event_interruptible_timeout( + spihid->wait, iface->hid_desc_len, SPIHID_DEF_WAIT); + } + + return 0; +error: + return err; +} +EXPORT_SYMBOL_GPL(spihid_apple_core_probe); + +void spihid_apple_core_remove(struct spi_device *spi) +{ + struct spihid_apple *spihid = spi_get_drvdata(spi); + + /* destroy input devices */ + + spihid_destroy_hid_device(&spihid->tp); + spihid_destroy_hid_device(&spihid->kbd); + + /* disable irq */ + spihid->ops->disable_irq(spihid->ops); + + /* power SPI device down */ + spihid->ops->power_off(spihid->ops); +} +EXPORT_SYMBOL_GPL(spihid_apple_core_remove); + +void spihid_apple_core_shutdown(struct spi_device *spi) +{ + struct spihid_apple *spihid = spi_get_drvdata(spi); + + /* disable irq */ + spihid->ops->disable_irq(spihid->ops); + + /* power SPI device down */ + spihid->ops->power_off(spihid->ops); +} +EXPORT_SYMBOL_GPL(spihid_apple_core_shutdown); + +#ifdef CONFIG_PM_SLEEP +static int spihid_apple_core_suspend(struct device *dev) +{ + int ret; +#ifdef IRQ_WAKE_SUPPORT + int wake_status; +#endif + struct spihid_apple *spihid = spi_get_drvdata(to_spi_device(dev)); + + if (spihid->tp.hid) { + ret = hid_driver_suspend(spihid->tp.hid, PMSG_SUSPEND); + if (ret < 0) + return ret; + } + + if (spihid->kbd.hid) { + ret = hid_driver_suspend(spihid->kbd.hid, PMSG_SUSPEND); + if (ret < 0) { + if (spihid->tp.hid) + hid_driver_resume(spihid->tp.hid); + return ret; + } + } + + /* Save some power */ + spihid->ops->disable_irq(spihid->ops); + +#ifdef IRQ_WAKE_SUPPORT + if (device_may_wakeup(dev)) { + wake_status = spihid->ops->enable_irq_wake(spihid->ops); + if (!wake_status) + spihid->irq_wake_enabled = true; + else + dev_warn(dev, "Failed to enable irq wake: %d\n", + wake_status); + } else { + spihid->ops->power_off(spihid->ops); + } +#else + spihid->ops->power_off(spihid->ops); +#endif + + return 0; +} + +static int spihid_apple_core_resume(struct device *dev) +{ + int ret_tp = 0, ret_kbd = 0; + struct spihid_apple *spihid = spi_get_drvdata(to_spi_device(dev)); +#ifdef IRQ_WAKE_SUPPORT + int wake_status; + + if (!device_may_wakeup(dev)) { + spihid->ops->power_on(spihid->ops); + } else if (spihid->irq_wake_enabled) { + wake_status = spihid->ops->disable_irq_wake(spihid->ops); + if (!wake_status) + spihid->irq_wake_enabled = false; + else + dev_warn(dev, "Failed to disable irq wake: %d\n", + wake_status); + } +#endif + + spihid->ops->enable_irq(spihid->ops); + spihid->ops->power_on(spihid->ops); + + if (spihid->tp.hid) + ret_tp = hid_driver_reset_resume(spihid->tp.hid); + if (spihid->kbd.hid) + ret_kbd = hid_driver_reset_resume(spihid->kbd.hid); + + if (ret_tp < 0) + return ret_tp; + + return ret_kbd; +} +#endif + +const struct dev_pm_ops spihid_apple_core_pm = { + SET_SYSTEM_SLEEP_PM_OPS(spihid_apple_core_suspend, + spihid_apple_core_resume) +}; +EXPORT_SYMBOL_GPL(spihid_apple_core_pm); + +MODULE_DESCRIPTION("Apple SPI HID transport driver"); +MODULE_AUTHOR("Janne Grunau "); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/spi-hid/spi-hid-apple-of.c b/drivers/hid/spi-hid/spi-hid-apple-of.c new file mode 100644 index 000000000000..3f87b299351d --- /dev/null +++ b/drivers/hid/spi-hid/spi-hid-apple-of.c @@ -0,0 +1,151 @@ +/* + * SPDX-License-Identifier: GPL-2.0 + * + * Apple SPI HID transport driver - Open Firmware + * + * Copyright (C) The Asahi Linux Contributors + */ + +#include +#include +#include +#include + +#include "spi-hid-apple.h" + + +struct spihid_apple_of { + struct spihid_apple_ops ops; + + struct gpio_desc *enable_gpio; + int irq; +}; + +static int spihid_apple_of_power_on(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + /* reset the controller on boot */ + gpiod_direction_output(sh_of->enable_gpio, 1); + msleep(5); + gpiod_direction_output(sh_of->enable_gpio, 0); + msleep(5); + /* turn SPI device on */ + gpiod_direction_output(sh_of->enable_gpio, 1); + msleep(50); + + return 0; +} + +static int spihid_apple_of_power_off(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + /* turn SPI device off */ + gpiod_direction_output(sh_of->enable_gpio, 0); + + return 0; +} + +static int spihid_apple_of_enable_irq(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + enable_irq(sh_of->irq); + + return 0; +} + +static int spihid_apple_of_disable_irq(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + disable_irq(sh_of->irq); + + return 0; +} + +static int spihid_apple_of_enable_irq_wake(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + return enable_irq_wake(sh_of->irq); +} + +static int spihid_apple_of_disable_irq_wake(struct spihid_apple_ops *ops) +{ + struct spihid_apple_of *sh_of = container_of(ops, struct spihid_apple_of, ops); + + return disable_irq_wake(sh_of->irq); +} + +static int spihid_apple_of_probe(struct spi_device *spi) +{ + struct device *dev = &spi->dev; + struct spihid_apple_of *spihid_of; + int err; + + spihid_of = devm_kzalloc(dev, sizeof(*spihid_of), GFP_KERNEL); + if (!spihid_of) + return -ENOMEM; + + spihid_of->ops.power_on = spihid_apple_of_power_on; + spihid_of->ops.power_off = spihid_apple_of_power_off; + spihid_of->ops.enable_irq = spihid_apple_of_enable_irq; + spihid_of->ops.disable_irq = spihid_apple_of_disable_irq; + spihid_of->ops.enable_irq_wake = spihid_apple_of_enable_irq_wake; + spihid_of->ops.disable_irq_wake = spihid_apple_of_disable_irq_wake; + + spihid_of->enable_gpio = devm_gpiod_get_index(dev, "spien", 0, 0); + if (IS_ERR(spihid_of->enable_gpio)) { + err = PTR_ERR(spihid_of->enable_gpio); + dev_err(dev, "failed to get 'spien' gpio pin: %d", err); + return err; + } + + spihid_of->irq = of_irq_get(dev->of_node, 0); + if (spihid_of->irq < 0) { + err = spihid_of->irq; + dev_err(dev, "failed to get 'extended-irq': %d", err); + return err; + } + err = devm_request_threaded_irq(dev, spihid_of->irq, NULL, + spihid_apple_core_irq, IRQF_ONESHOT | IRQF_NO_AUTOEN, + "spi-hid-apple-irq", spi); + if (err < 0) { + dev_err(dev, "failed to request extended-irq %d: %d", + spihid_of->irq, err); + return err; + } + + return spihid_apple_core_probe(spi, &spihid_of->ops); +} + +static const struct of_device_id spihid_apple_of_match[] = { + { .compatible = "apple,spi-hid-transport" }, + {}, +}; +MODULE_DEVICE_TABLE(of, spihid_apple_of_match); + +static struct spi_device_id spihid_apple_of_id[] = { + { "spi-hid-transport", 0 }, + {} +}; +MODULE_DEVICE_TABLE(spi, spihid_apple_of_id); + +static struct spi_driver spihid_apple_of_driver = { + .driver = { + .name = "spi-hid-apple-of", + .pm = &spihid_apple_core_pm, + .of_match_table = of_match_ptr(spihid_apple_of_match), + }, + + .id_table = spihid_apple_of_id, + .probe = spihid_apple_of_probe, + .remove = spihid_apple_core_remove, + .shutdown = spihid_apple_core_shutdown, +}; + +module_spi_driver(spihid_apple_of_driver); + +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/spi-hid/spi-hid-apple.h b/drivers/hid/spi-hid/spi-hid-apple.h new file mode 100644 index 000000000000..9abecd1ba780 --- /dev/null +++ b/drivers/hid/spi-hid/spi-hid-apple.h @@ -0,0 +1,35 @@ +/* SPDX-License-Identifier: GPL-2.0-only OR MIT */ + +#ifndef SPI_HID_APPLE_H +#define SPI_HID_APPLE_H + +#include +#include + +/** + * struct spihid_apple_ops - Ops to control the device from the core driver. + * + * @power_on: reset and power the device on. + * @power_off: power the device off. + * @enable_irq: enable irq or ACPI gpe. + * @disable_irq: disable irq or ACPI gpe. + */ + +struct spihid_apple_ops { + int (*power_on)(struct spihid_apple_ops *ops); + int (*power_off)(struct spihid_apple_ops *ops); + int (*enable_irq)(struct spihid_apple_ops *ops); + int (*disable_irq)(struct spihid_apple_ops *ops); + int (*enable_irq_wake)(struct spihid_apple_ops *ops); + int (*disable_irq_wake)(struct spihid_apple_ops *ops); +}; + +irqreturn_t spihid_apple_core_irq(int irq, void *data); + +int spihid_apple_core_probe(struct spi_device *spi, struct spihid_apple_ops *ops); +void spihid_apple_core_remove(struct spi_device *spi); +void spihid_apple_core_shutdown(struct spi_device *spi); + +extern const struct dev_pm_ops spihid_apple_core_pm; + +#endif /* SPI_HID_APPLE_H */ diff --git a/drivers/hwmon/applesmc.c b/drivers/hwmon/applesmc.c index fc6d6a9053ce..698f44794453 100644 --- a/drivers/hwmon/applesmc.c +++ b/drivers/hwmon/applesmc.c @@ -6,6 +6,7 @@ * * Copyright (C) 2007 Nicolas Boichat * Copyright (C) 2010 Henrik Rydberg + * Copyright (C) 2019 Paul Pawlowski * * Based on hdaps.c driver: * Copyright (C) 2005 Robert Love @@ -18,7 +19,7 @@ #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt #include -#include +#include #include #include #include @@ -35,12 +36,24 @@ #include /* data port used by Apple SMC */ -#define APPLESMC_DATA_PORT 0x300 +#define APPLESMC_DATA_PORT 0 /* command/status port used by Apple SMC */ -#define APPLESMC_CMD_PORT 0x304 +#define APPLESMC_CMD_PORT 4 #define APPLESMC_NR_PORTS 32 /* 0x300-0x31f */ +#define APPLESMC_IOMEM_KEY_DATA 0 +#define APPLESMC_IOMEM_KEY_STATUS 0x4005 +#define APPLESMC_IOMEM_KEY_NAME 0x78 +#define APPLESMC_IOMEM_KEY_DATA_LEN 0x7D +#define APPLESMC_IOMEM_KEY_SMC_ID 0x7E +#define APPLESMC_IOMEM_KEY_CMD 0x7F +#define APPLESMC_IOMEM_MIN_SIZE 0x4006 + +#define APPLESMC_IOMEM_KEY_TYPE_CODE 0 +#define APPLESMC_IOMEM_KEY_TYPE_DATA_LEN 5 +#define APPLESMC_IOMEM_KEY_TYPE_FLAGS 6 + #define APPLESMC_MAX_DATA_LENGTH 32 /* Apple SMC status bits */ @@ -74,6 +87,7 @@ #define FAN_ID_FMT "F%dID" /* r-o char[16] */ #define TEMP_SENSOR_TYPE "sp78" +#define FLOAT_TYPE "flt " /* List of keys used to read/write fan speeds */ static const char *const fan_speed_fmt[] = { @@ -83,6 +97,7 @@ static const char *const fan_speed_fmt[] = { "F%dSf", /* safe speed - not all models */ "F%dTg", /* target speed (manual: rw) */ }; +#define FAN_MANUAL_FMT "F%dMd" #define INIT_TIMEOUT_MSECS 5000 /* wait up to 5s for device init ... */ #define INIT_WAIT_MSECS 50 /* ... in 50ms increments */ @@ -119,7 +134,7 @@ struct applesmc_entry { }; /* Register lookup and registers common to all SMCs */ -static struct applesmc_registers { +struct applesmc_registers { struct mutex mutex; /* register read/write mutex */ unsigned int key_count; /* number of SMC registers */ unsigned int fan_count; /* number of fans */ @@ -133,26 +148,38 @@ static struct applesmc_registers { bool init_complete; /* true when fully initialized */ struct applesmc_entry *cache; /* cached key entries */ const char **index; /* temperature key index */ -} smcreg = { - .mutex = __MUTEX_INITIALIZER(smcreg.mutex), }; -static const int debug; -static struct platform_device *pdev; -static s16 rest_x; -static s16 rest_y; -static u8 backlight_state[2]; +struct applesmc_device { + struct acpi_device *dev; + struct device *ldev; + struct applesmc_registers reg; -static struct device *hwmon_dev; -static struct input_dev *applesmc_idev; + bool port_base_set, iomem_base_set; + u16 port_base; + u8 *__iomem iomem_base; + u32 iomem_base_addr, iomem_base_size; -/* - * Last index written to key_at_index sysfs file, and value to use for all other - * key_at_index_* sysfs files. - */ -static unsigned int key_at_index; + s16 rest_x; + s16 rest_y; + + u8 backlight_state[2]; + + struct device *hwmon_dev; + struct input_dev *idev; + + /* + * Last index written to key_at_index sysfs file, and value to use for all other + * key_at_index_* sysfs files. + */ + unsigned int key_at_index; -static struct workqueue_struct *applesmc_led_wq; + struct workqueue_struct *backlight_wq; + struct work_struct backlight_work; + struct led_classdev backlight_dev; +}; + +static const int debug; /* * Wait for specific status bits with a mask on the SMC. @@ -162,7 +189,7 @@ static struct workqueue_struct *applesmc_led_wq; * run out past 500ms. */ -static int wait_status(u8 val, u8 mask) +static int port_wait_status(struct applesmc_device *smc, u8 val, u8 mask) { u8 status; int us; @@ -170,7 +197,7 @@ static int wait_status(u8 val, u8 mask) us = APPLESMC_MIN_WAIT; for (i = 0; i < 24 ; i++) { - status = inb(APPLESMC_CMD_PORT); + status = inb(smc->port_base + APPLESMC_CMD_PORT); if ((status & mask) == val) return 0; usleep_range(us, us * 2); @@ -180,13 +207,13 @@ static int wait_status(u8 val, u8 mask) return -EIO; } -/* send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ +/* port_send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ -static int send_byte(u8 cmd, u16 port) +static int port_send_byte(struct applesmc_device *smc, u8 cmd, u16 port) { int status; - status = wait_status(0, SMC_STATUS_IB_CLOSED); + status = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); if (status) return status; /* @@ -195,24 +222,25 @@ static int send_byte(u8 cmd, u16 port) * this extra read may not happen if status returns both * simultaneously and this would appear to be required. */ - status = wait_status(SMC_STATUS_BUSY, SMC_STATUS_BUSY); + status = port_wait_status(smc, SMC_STATUS_BUSY, SMC_STATUS_BUSY); if (status) return status; - outb(cmd, port); + outb(cmd, smc->port_base + port); return 0; } -/* send_command - Write a command to the SMC. Callers must hold applesmc_lock. */ +/* port_send_command - Write a command to the SMC. Callers must hold applesmc_lock. */ -static int send_command(u8 cmd) +static int port_send_command(struct applesmc_device *smc, u8 cmd) { int ret; - ret = wait_status(0, SMC_STATUS_IB_CLOSED); + ret = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); if (ret) return ret; - outb(cmd, APPLESMC_CMD_PORT); + + outb(cmd, smc->port_base + APPLESMC_CMD_PORT); return 0; } @@ -222,110 +250,304 @@ static int send_command(u8 cmd) * If busy is stuck high after the command then the SMC is jammed. */ -static int smc_sane(void) +static int port_smc_sane(struct applesmc_device *smc) { int ret; - ret = wait_status(0, SMC_STATUS_BUSY); + ret = port_wait_status(smc, 0, SMC_STATUS_BUSY); if (!ret) return ret; - ret = send_command(APPLESMC_READ_CMD); + ret = port_send_command(smc, APPLESMC_READ_CMD); if (ret) return ret; - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int send_argument(const char *key) +static int port_send_argument(struct applesmc_device *smc, const char *key) { int i; for (i = 0; i < 4; i++) - if (send_byte(key[i], APPLESMC_DATA_PORT)) + if (port_send_byte(smc, key[i], APPLESMC_DATA_PORT)) return -EIO; return 0; } -static int read_smc(u8 cmd, const char *key, u8 *buffer, u8 len) +static int port_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, + u8 *buffer, u8 len) { u8 status, data = 0; int i; int ret; - ret = smc_sane(); + ret = port_smc_sane(smc); if (ret) return ret; - if (send_command(cmd) || send_argument(key)) { + if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { pr_warn("%.4s: read arg fail\n", key); return -EIO; } /* This has no effect on newer (2012) SMCs */ - if (send_byte(len, APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { pr_warn("%.4s: read len fail\n", key); return -EIO; } for (i = 0; i < len; i++) { - if (wait_status(SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY, + if (port_wait_status(smc, + SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY, SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY)) { pr_warn("%.4s: read data[%d] fail\n", key, i); return -EIO; } - buffer[i] = inb(APPLESMC_DATA_PORT); + buffer[i] = inb(smc->port_base + APPLESMC_DATA_PORT); } /* Read the data port until bit0 is cleared */ for (i = 0; i < 16; i++) { udelay(APPLESMC_MIN_WAIT); - status = inb(APPLESMC_CMD_PORT); + status = inb(smc->port_base + APPLESMC_CMD_PORT); if (!(status & SMC_STATUS_AWAITING_DATA)) break; - data = inb(APPLESMC_DATA_PORT); + data = inb(smc->port_base + APPLESMC_DATA_PORT); } if (i) pr_warn("flushed %d bytes, last value is: %d\n", i, data); - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int write_smc(u8 cmd, const char *key, const u8 *buffer, u8 len) +static int port_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, + const u8 *buffer, u8 len) { int i; int ret; - ret = smc_sane(); + ret = port_smc_sane(smc); if (ret) return ret; - if (send_command(cmd) || send_argument(key)) { + if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { pr_warn("%s: write arg fail\n", key); return -EIO; } - if (send_byte(len, APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { pr_warn("%.4s: write len fail\n", key); return -EIO; } for (i = 0; i < len; i++) { - if (send_byte(buffer[i], APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, buffer[i], APPLESMC_DATA_PORT)) { pr_warn("%s: write data fail\n", key); return -EIO; } } - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int read_register_count(unsigned int *count) +static int port_get_smc_key_info(struct applesmc_device *smc, + const char *key, struct applesmc_entry *info) { - __be32 be; int ret; + u8 raw[6]; - ret = read_smc(APPLESMC_READ_CMD, KEY_COUNT_KEY, (u8 *)&be, 4); + ret = port_read_smc(smc, APPLESMC_GET_KEY_TYPE_CMD, key, raw, 6); if (ret) return ret; + info->len = raw[0]; + memcpy(info->type, &raw[1], 4); + info->flags = raw[5]; + return 0; +} + + +/* + * MMIO based communication. + * TODO: Use updated mechanism for cmd timeout/retry + */ + +static void iomem_clear_status(struct applesmc_device *smc) +{ + if (ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS)) + iowrite8(0, smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); +} + +static int iomem_wait_read(struct applesmc_device *smc) +{ + u8 status; + int us; + int i; + + us = APPLESMC_MIN_WAIT; + for (i = 0; i < 24 ; i++) { + status = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); + if (status & 0x20) + return 0; + usleep_range(us, us * 2); + if (i > 9) + us <<= 1; + } + + dev_warn(smc->ldev, "%s... timeout\n", __func__); + return -EIO; +} + +static int iomem_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, + u8 *buffer, u8 len) +{ + u8 err, remote_len; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "read_smc_mmio(%x %8x/%.4s) failed: %u\n", + cmd, key_int, key, err); + return -EIO; + } + + if (cmd == APPLESMC_READ_CMD) { + remote_len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); + if (remote_len != len) { + dev_warn(smc->ldev, + "read_smc_mmio(%x %8x/%.4s) failed: buffer length mismatch (remote = %u, requested = %u)\n", + cmd, key_int, key, remote_len, len); + return -EINVAL; + } + } else { + remote_len = len; + } + + memcpy_fromio(buffer, smc->iomem_base + APPLESMC_IOMEM_KEY_DATA, + remote_len); + + dev_dbg(smc->ldev, "read_smc_mmio(%x %8x/%.4s): buflen=%u reslen=%u\n", + cmd, key_int, key, len, remote_len); + print_hex_dump_bytes("read_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, remote_len); + return 0; +} + +static int iomem_get_smc_key_type(struct applesmc_device *smc, const char *key, + struct applesmc_entry *e) +{ + u8 err; + u8 cmd = APPLESMC_GET_KEY_TYPE_CMD; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "get_smc_key_type_mmio(%.4s) failed: %u\n", key, err); + return -EIO; + } + + e->len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_DATA_LEN); + *((uint32_t *) e->type) = ioread32( + smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_CODE); + e->flags = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_FLAGS); + + dev_dbg(smc->ldev, "get_smc_key_type_mmio(%.4s): len=%u type=%.4s flags=%x\n", + key, e->len, e->type, e->flags); + return 0; +} + +static int iomem_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, + const u8 *buffer, u8 len) +{ + u8 err; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + memcpy_toio(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA, buffer, len); + iowrite32(len, smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "write_smc_mmio(%x %.4s) failed: %u\n", cmd, key, err); + print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); + return -EIO; + } + + dev_dbg(smc->ldev, "write_smc_mmio(%x %.4s): buflen=%u\n", cmd, key, len); + print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); + return 0; +} + + +static int read_smc(struct applesmc_device *smc, const char *key, + u8 *buffer, u8 len) +{ + if (smc->iomem_base_set) + return iomem_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); + else + return port_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); +} + +static int write_smc(struct applesmc_device *smc, const char *key, + const u8 *buffer, u8 len) +{ + if (smc->iomem_base_set) + return iomem_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); + else + return port_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); +} + +static int get_smc_key_by_index(struct applesmc_device *smc, + unsigned int index, char *key) +{ + __be32 be; + + be = cpu_to_be32(index); + if (smc->iomem_base_set) + return iomem_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, + (const char *) &be, (u8 *) key, 4); + else + return port_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, + (const char *) &be, (u8 *) key, 4); +} + +static int get_smc_key_info(struct applesmc_device *smc, const char *key, + struct applesmc_entry *info) +{ + if (smc->iomem_base_set) + return iomem_get_smc_key_type(smc, key, info); + else + return port_get_smc_key_info(smc, key, info); +} + +static int read_register_count(struct applesmc_device *smc, + unsigned int *count) +{ + __be32 be; + int ret; + + ret = read_smc(smc, KEY_COUNT_KEY, (u8 *)&be, 4); + if (ret < 0) + return ret; *count = be32_to_cpu(be); return 0; @@ -338,76 +560,73 @@ static int read_register_count(unsigned int *count) * All functions below are concurrency safe - callers should NOT hold lock. */ -static int applesmc_read_entry(const struct applesmc_entry *entry, - u8 *buf, u8 len) +static int applesmc_read_entry(struct applesmc_device *smc, + const struct applesmc_entry *entry, u8 *buf, u8 len) { int ret; if (entry->len != len) return -EINVAL; - mutex_lock(&smcreg.mutex); - ret = read_smc(APPLESMC_READ_CMD, entry->key, buf, len); - mutex_unlock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); + ret = read_smc(smc, entry->key, buf, len); + mutex_unlock(&smc->reg.mutex); return ret; } -static int applesmc_write_entry(const struct applesmc_entry *entry, - const u8 *buf, u8 len) +static int applesmc_write_entry(struct applesmc_device *smc, + const struct applesmc_entry *entry, const u8 *buf, u8 len) { int ret; if (entry->len != len) return -EINVAL; - mutex_lock(&smcreg.mutex); - ret = write_smc(APPLESMC_WRITE_CMD, entry->key, buf, len); - mutex_unlock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); + ret = write_smc(smc, entry->key, buf, len); + mutex_unlock(&smc->reg.mutex); return ret; } -static const struct applesmc_entry *applesmc_get_entry_by_index(int index) +static const struct applesmc_entry *applesmc_get_entry_by_index( + struct applesmc_device *smc, int index) { - struct applesmc_entry *cache = &smcreg.cache[index]; - u8 key[4], info[6]; - __be32 be; + struct applesmc_entry *cache = &smc->reg.cache[index]; + char key[4]; int ret = 0; if (cache->valid) return cache; - mutex_lock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); if (cache->valid) goto out; - be = cpu_to_be32(index); - ret = read_smc(APPLESMC_GET_KEY_BY_INDEX_CMD, (u8 *)&be, key, 4); + ret = get_smc_key_by_index(smc, index, key); if (ret) goto out; - ret = read_smc(APPLESMC_GET_KEY_TYPE_CMD, key, info, 6); + memcpy(cache->key, key, 4); + + ret = get_smc_key_info(smc, key, cache); if (ret) goto out; - - memcpy(cache->key, key, 4); - cache->len = info[0]; - memcpy(cache->type, &info[1], 4); - cache->flags = info[5]; cache->valid = true; out: - mutex_unlock(&smcreg.mutex); + mutex_unlock(&smc->reg.mutex); if (ret) return ERR_PTR(ret); return cache; } -static int applesmc_get_lower_bound(unsigned int *lo, const char *key) +static int applesmc_get_lower_bound(struct applesmc_device *smc, + unsigned int *lo, const char *key) { - int begin = 0, end = smcreg.key_count; + int begin = 0, end = smc->reg.key_count; const struct applesmc_entry *entry; while (begin != end) { int middle = begin + (end - begin) / 2; - entry = applesmc_get_entry_by_index(middle); + entry = applesmc_get_entry_by_index(smc, middle); if (IS_ERR(entry)) { *lo = 0; return PTR_ERR(entry); @@ -422,16 +641,17 @@ static int applesmc_get_lower_bound(unsigned int *lo, const char *key) return 0; } -static int applesmc_get_upper_bound(unsigned int *hi, const char *key) +static int applesmc_get_upper_bound(struct applesmc_device *smc, + unsigned int *hi, const char *key) { - int begin = 0, end = smcreg.key_count; + int begin = 0, end = smc->reg.key_count; const struct applesmc_entry *entry; while (begin != end) { int middle = begin + (end - begin) / 2; - entry = applesmc_get_entry_by_index(middle); + entry = applesmc_get_entry_by_index(smc, middle); if (IS_ERR(entry)) { - *hi = smcreg.key_count; + *hi = smc->reg.key_count; return PTR_ERR(entry); } if (strcmp(key, entry->key) < 0) @@ -444,50 +664,54 @@ static int applesmc_get_upper_bound(unsigned int *hi, const char *key) return 0; } -static const struct applesmc_entry *applesmc_get_entry_by_key(const char *key) +static const struct applesmc_entry *applesmc_get_entry_by_key( + struct applesmc_device *smc, const char *key) { int begin, end; int ret; - ret = applesmc_get_lower_bound(&begin, key); + ret = applesmc_get_lower_bound(smc, &begin, key); if (ret) return ERR_PTR(ret); - ret = applesmc_get_upper_bound(&end, key); + ret = applesmc_get_upper_bound(smc, &end, key); if (ret) return ERR_PTR(ret); if (end - begin != 1) return ERR_PTR(-EINVAL); - return applesmc_get_entry_by_index(begin); + return applesmc_get_entry_by_index(smc, begin); } -static int applesmc_read_key(const char *key, u8 *buffer, u8 len) +static int applesmc_read_key(struct applesmc_device *smc, + const char *key, u8 *buffer, u8 len) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry)) return PTR_ERR(entry); - return applesmc_read_entry(entry, buffer, len); + return applesmc_read_entry(smc, entry, buffer, len); } -static int applesmc_write_key(const char *key, const u8 *buffer, u8 len) +static int applesmc_write_key(struct applesmc_device *smc, + const char *key, const u8 *buffer, u8 len) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry)) return PTR_ERR(entry); - return applesmc_write_entry(entry, buffer, len); + return applesmc_write_entry(smc, entry, buffer, len); } -static int applesmc_has_key(const char *key, bool *value) +static int applesmc_has_key(struct applesmc_device *smc, + const char *key, bool *value) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry) && PTR_ERR(entry) != -EINVAL) return PTR_ERR(entry); @@ -498,12 +722,13 @@ static int applesmc_has_key(const char *key, bool *value) /* * applesmc_read_s16 - Read 16-bit signed big endian register */ -static int applesmc_read_s16(const char *key, s16 *value) +static int applesmc_read_s16(struct applesmc_device *smc, + const char *key, s16 *value) { u8 buffer[2]; int ret; - ret = applesmc_read_key(key, buffer, 2); + ret = applesmc_read_key(smc, key, buffer, 2); if (ret) return ret; @@ -511,31 +736,68 @@ static int applesmc_read_s16(const char *key, s16 *value) return 0; } +/** + * applesmc_float_to_u32 - Retrieve the integral part of a float. + * This is needed because Apple made fans use float values in the T2. + * The fractional point is not significantly useful though, and the integral + * part can be easily extracted. + */ +static inline u32 applesmc_float_to_u32(u32 d) +{ + u8 sign = (u8) ((d >> 31) & 1); + s32 exp = (s32) ((d >> 23) & 0xff) - 0x7f; + u32 fr = d & ((1u << 23) - 1); + + if (sign || exp < 0) + return 0; + + return (u32) ((1u << exp) + (fr >> (23 - exp))); +} + +/** + * applesmc_u32_to_float - Convert an u32 into a float. + * See applesmc_float_to_u32 for a rationale. + */ +static inline u32 applesmc_u32_to_float(u32 d) +{ + u32 dc = d, bc = 0, exp; + + if (!d) + return 0; + + while (dc >>= 1) + ++bc; + exp = 0x7f + bc; + + return (u32) ((exp << 23) | + ((d << (23 - (exp - 0x7f))) & ((1u << 23) - 1))); +} /* * applesmc_device_init - initialize the accelerometer. Can sleep. */ -static void applesmc_device_init(void) +static void applesmc_device_init(struct applesmc_device *smc) { int total; u8 buffer[2]; - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return; for (total = INIT_TIMEOUT_MSECS; total > 0; total -= INIT_WAIT_MSECS) { - if (!applesmc_read_key(MOTION_SENSOR_KEY, buffer, 2) && + if (!applesmc_read_key(smc, MOTION_SENSOR_KEY, buffer, 2) && (buffer[0] != 0x00 || buffer[1] != 0x00)) return; buffer[0] = 0xe0; buffer[1] = 0x00; - applesmc_write_key(MOTION_SENSOR_KEY, buffer, 2); + applesmc_write_key(smc, MOTION_SENSOR_KEY, buffer, 2); msleep(INIT_WAIT_MSECS); } pr_warn("failed to init the device\n"); } -static int applesmc_init_index(struct applesmc_registers *s) +static int applesmc_init_index(struct applesmc_device *smc, + struct applesmc_registers *s) { const struct applesmc_entry *entry; unsigned int i; @@ -548,7 +810,7 @@ static int applesmc_init_index(struct applesmc_registers *s) return -ENOMEM; for (i = s->temp_begin; i < s->temp_end; i++) { - entry = applesmc_get_entry_by_index(i); + entry = applesmc_get_entry_by_index(smc, i); if (IS_ERR(entry)) continue; if (strcmp(entry->type, TEMP_SENSOR_TYPE)) @@ -562,9 +824,9 @@ static int applesmc_init_index(struct applesmc_registers *s) /* * applesmc_init_smcreg_try - Try to initialize register cache. Idempotent. */ -static int applesmc_init_smcreg_try(void) +static int applesmc_init_smcreg_try(struct applesmc_device *smc) { - struct applesmc_registers *s = &smcreg; + struct applesmc_registers *s = &smc->reg; bool left_light_sensor = false, right_light_sensor = false; unsigned int count; u8 tmp[1]; @@ -573,7 +835,7 @@ static int applesmc_init_smcreg_try(void) if (s->init_complete) return 0; - ret = read_register_count(&count); + ret = read_register_count(smc, &count); if (ret) return ret; @@ -590,35 +852,35 @@ static int applesmc_init_smcreg_try(void) if (!s->cache) return -ENOMEM; - ret = applesmc_read_key(FANS_COUNT, tmp, 1); + ret = applesmc_read_key(smc, FANS_COUNT, tmp, 1); if (ret) return ret; s->fan_count = tmp[0]; if (s->fan_count > 10) s->fan_count = 10; - ret = applesmc_get_lower_bound(&s->temp_begin, "T"); + ret = applesmc_get_lower_bound(smc, &s->temp_begin, "T"); if (ret) return ret; - ret = applesmc_get_lower_bound(&s->temp_end, "U"); + ret = applesmc_get_lower_bound(smc, &s->temp_end, "U"); if (ret) return ret; s->temp_count = s->temp_end - s->temp_begin; - ret = applesmc_init_index(s); + ret = applesmc_init_index(smc, s); if (ret) return ret; - ret = applesmc_has_key(LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); + ret = applesmc_has_key(smc, LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); if (ret) return ret; - ret = applesmc_has_key(LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); + ret = applesmc_has_key(smc, LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); if (ret) return ret; - ret = applesmc_has_key(MOTION_SENSOR_KEY, &s->has_accelerometer); + ret = applesmc_has_key(smc, MOTION_SENSOR_KEY, &s->has_accelerometer); if (ret) return ret; - ret = applesmc_has_key(BACKLIGHT_KEY, &s->has_key_backlight); + ret = applesmc_has_key(smc, BACKLIGHT_KEY, &s->has_key_backlight); if (ret) return ret; @@ -634,13 +896,13 @@ static int applesmc_init_smcreg_try(void) return 0; } -static void applesmc_destroy_smcreg(void) +static void applesmc_destroy_smcreg(struct applesmc_device *smc) { - kfree(smcreg.index); - smcreg.index = NULL; - kfree(smcreg.cache); - smcreg.cache = NULL; - smcreg.init_complete = false; + kfree(smc->reg.index); + smc->reg.index = NULL; + kfree(smc->reg.cache); + smc->reg.cache = NULL; + smc->reg.init_complete = false; } /* @@ -649,12 +911,12 @@ static void applesmc_destroy_smcreg(void) * Retries until initialization is successful, or the operation times out. * */ -static int applesmc_init_smcreg(void) +static int applesmc_init_smcreg(struct applesmc_device *smc) { int ms, ret; for (ms = 0; ms < INIT_TIMEOUT_MSECS; ms += INIT_WAIT_MSECS) { - ret = applesmc_init_smcreg_try(); + ret = applesmc_init_smcreg_try(smc); if (!ret) { if (ms) pr_info("init_smcreg() took %d ms\n", ms); @@ -663,50 +925,223 @@ static int applesmc_init_smcreg(void) msleep(INIT_WAIT_MSECS); } - applesmc_destroy_smcreg(); + applesmc_destroy_smcreg(smc); return ret; } /* Device model stuff */ -static int applesmc_probe(struct platform_device *dev) + +static int applesmc_init_resources(struct applesmc_device *smc); +static void applesmc_free_resources(struct applesmc_device *smc); +static int applesmc_create_modules(struct applesmc_device *smc); +static void applesmc_destroy_modules(struct applesmc_device *smc); + +static int applesmc_add(struct acpi_device *dev) { + struct applesmc_device *smc; int ret; - ret = applesmc_init_smcreg(); + smc = kzalloc(sizeof(struct applesmc_device), GFP_KERNEL); + if (!smc) + return -ENOMEM; + smc->dev = dev; + smc->ldev = &dev->dev; + mutex_init(&smc->reg.mutex); + + dev_set_drvdata(&dev->dev, smc); + + ret = applesmc_init_resources(smc); if (ret) - return ret; + goto out_mem; + + ret = applesmc_init_smcreg(smc); + if (ret) + goto out_res; + + applesmc_device_init(smc); + + ret = applesmc_create_modules(smc); + if (ret) + goto out_reg; + + return 0; + +out_reg: + applesmc_destroy_smcreg(smc); +out_res: + applesmc_free_resources(smc); +out_mem: + dev_set_drvdata(&dev->dev, NULL); + mutex_destroy(&smc->reg.mutex); + kfree(smc); + + return ret; +} + +static void applesmc_remove(struct acpi_device *dev) +{ + struct applesmc_device *smc = dev_get_drvdata(&dev->dev); + + applesmc_destroy_modules(smc); + applesmc_destroy_smcreg(smc); + applesmc_free_resources(smc); - applesmc_device_init(); + mutex_destroy(&smc->reg.mutex); + kfree(smc); + + return; +} + +static acpi_status applesmc_walk_resources(struct acpi_resource *res, + void *data) +{ + struct applesmc_device *smc = data; + + switch (res->type) { + case ACPI_RESOURCE_TYPE_IO: + if (!smc->port_base_set) { + if (res->data.io.address_length < APPLESMC_NR_PORTS) + return AE_ERROR; + smc->port_base = res->data.io.minimum; + smc->port_base_set = true; + } + return AE_OK; + + case ACPI_RESOURCE_TYPE_FIXED_MEMORY32: + if (!smc->iomem_base_set) { + if (res->data.fixed_memory32.address_length < + APPLESMC_IOMEM_MIN_SIZE) { + dev_warn(smc->ldev, "found iomem but it's too small: %u\n", + res->data.fixed_memory32.address_length); + return AE_OK; + } + smc->iomem_base_addr = res->data.fixed_memory32.address; + smc->iomem_base_size = res->data.fixed_memory32.address_length; + smc->iomem_base_set = true; + } + return AE_OK; + + case ACPI_RESOURCE_TYPE_END_TAG: + if (smc->port_base_set) + return AE_OK; + else + return AE_NOT_FOUND; + + default: + return AE_OK; + } +} + +static int applesmc_try_enable_iomem(struct applesmc_device *smc); + +static int applesmc_init_resources(struct applesmc_device *smc) +{ + int ret; + + ret = acpi_walk_resources(smc->dev->handle, METHOD_NAME__CRS, + applesmc_walk_resources, smc); + if (ACPI_FAILURE(ret)) + return -ENXIO; + + if (!request_region(smc->port_base, APPLESMC_NR_PORTS, "applesmc")) + return -ENXIO; + + if (smc->iomem_base_set) { + if (applesmc_try_enable_iomem(smc)) + smc->iomem_base_set = false; + } + + return 0; +} + +static int applesmc_try_enable_iomem(struct applesmc_device *smc) +{ + u8 test_val, ldkn_version; + + dev_dbg(smc->ldev, "Trying to enable iomem based communication\n"); + smc->iomem_base = ioremap(smc->iomem_base_addr, smc->iomem_base_size); + if (!smc->iomem_base) + goto out; + + /* Apple's driver does this check for some reason */ + test_val = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); + if (test_val == 0xff) { + dev_warn(smc->ldev, + "iomem enable failed: initial status is 0xff (is %x)\n", + test_val); + goto out_iomem; + } + + if (read_smc(smc, "LDKN", &ldkn_version, 1)) { + dev_warn(smc->ldev, "iomem enable failed: ldkn read failed\n"); + goto out_iomem; + } + + if (ldkn_version < 2) { + dev_warn(smc->ldev, + "iomem enable failed: ldkn version %u is less than minimum (2)\n", + ldkn_version); + goto out_iomem; + } return 0; + +out_iomem: + iounmap(smc->iomem_base); + +out: + return -ENXIO; +} + +static void applesmc_free_resources(struct applesmc_device *smc) +{ + if (smc->iomem_base_set) + iounmap(smc->iomem_base); + release_region(smc->port_base, APPLESMC_NR_PORTS); } /* Synchronize device with memorized backlight state */ static int applesmc_pm_resume(struct device *dev) { - if (smcreg.has_key_backlight) - applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); + struct applesmc_device *smc = dev_get_drvdata(dev); + + if (smc->reg.has_key_backlight) + applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); + return 0; } /* Reinitialize device on resume from hibernation */ static int applesmc_pm_restore(struct device *dev) { - applesmc_device_init(); + struct applesmc_device *smc = dev_get_drvdata(dev); + + applesmc_device_init(smc); + return applesmc_pm_resume(dev); } +static const struct acpi_device_id applesmc_ids[] = { + {"APP0001", 0}, + {"", 0}, +}; + static const struct dev_pm_ops applesmc_pm_ops = { .resume = applesmc_pm_resume, .restore = applesmc_pm_restore, }; -static struct platform_driver applesmc_driver = { - .probe = applesmc_probe, - .driver = { - .name = "applesmc", - .pm = &applesmc_pm_ops, +static struct acpi_driver applesmc_driver = { + .name = "applesmc", + .class = "applesmc", + .ids = applesmc_ids, + .ops = { + .add = applesmc_add, + .remove = applesmc_remove + }, + .drv = { + .pm = &applesmc_pm_ops }, }; @@ -714,25 +1149,26 @@ static struct platform_driver applesmc_driver = { * applesmc_calibrate - Set our "resting" values. Callers must * hold applesmc_lock. */ -static void applesmc_calibrate(void) +static void applesmc_calibrate(struct applesmc_device *smc) { - applesmc_read_s16(MOTION_SENSOR_X_KEY, &rest_x); - applesmc_read_s16(MOTION_SENSOR_Y_KEY, &rest_y); - rest_x = -rest_x; + applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &smc->rest_x); + applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &smc->rest_y); + smc->rest_x = -smc->rest_x; } static void applesmc_idev_poll(struct input_dev *idev) { + struct applesmc_device *smc = dev_get_drvdata(&idev->dev); s16 x, y; - if (applesmc_read_s16(MOTION_SENSOR_X_KEY, &x)) + if (applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x)) return; - if (applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y)) + if (applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y)) return; x = -x; - input_report_abs(idev, ABS_X, x - rest_x); - input_report_abs(idev, ABS_Y, y - rest_y); + input_report_abs(idev, ABS_X, x - smc->rest_x); + input_report_abs(idev, ABS_Y, y - smc->rest_y); input_sync(idev); } @@ -747,16 +1183,17 @@ static ssize_t applesmc_name_show(struct device *dev, static ssize_t applesmc_position_show(struct device *dev, struct device_attribute *attr, char *buf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; s16 x, y, z; - ret = applesmc_read_s16(MOTION_SENSOR_X_KEY, &x); + ret = applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x); if (ret) goto out; - ret = applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y); + ret = applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y); if (ret) goto out; - ret = applesmc_read_s16(MOTION_SENSOR_Z_KEY, &z); + ret = applesmc_read_s16(smc, MOTION_SENSOR_Z_KEY, &z); if (ret) goto out; @@ -770,6 +1207,7 @@ static ssize_t applesmc_position_show(struct device *dev, static ssize_t applesmc_light_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; static int data_length; int ret; @@ -777,7 +1215,7 @@ static ssize_t applesmc_light_show(struct device *dev, u8 buffer[10]; if (!data_length) { - entry = applesmc_get_entry_by_key(LIGHT_SENSOR_LEFT_KEY); + entry = applesmc_get_entry_by_key(smc, LIGHT_SENSOR_LEFT_KEY); if (IS_ERR(entry)) return PTR_ERR(entry); if (entry->len > 10) @@ -786,7 +1224,7 @@ static ssize_t applesmc_light_show(struct device *dev, pr_info("light sensor data length set to %d\n", data_length); } - ret = applesmc_read_key(LIGHT_SENSOR_LEFT_KEY, buffer, data_length); + ret = applesmc_read_key(smc, LIGHT_SENSOR_LEFT_KEY, buffer, data_length); if (ret) goto out; /* newer macbooks report a single 10-bit bigendian value */ @@ -796,7 +1234,7 @@ static ssize_t applesmc_light_show(struct device *dev, } left = buffer[2]; - ret = applesmc_read_key(LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); + ret = applesmc_read_key(smc, LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); if (ret) goto out; right = buffer[2]; @@ -812,7 +1250,8 @@ static ssize_t applesmc_light_show(struct device *dev, static ssize_t applesmc_show_sensor_label(struct device *dev, struct device_attribute *devattr, char *sysfsbuf) { - const char *key = smcreg.index[to_index(devattr)]; + struct applesmc_device *smc = dev_get_drvdata(dev); + const char *key = smc->reg.index[to_index(devattr)]; return sysfs_emit(sysfsbuf, "%s\n", key); } @@ -821,12 +1260,13 @@ static ssize_t applesmc_show_sensor_label(struct device *dev, static ssize_t applesmc_show_temperature(struct device *dev, struct device_attribute *devattr, char *sysfsbuf) { - const char *key = smcreg.index[to_index(devattr)]; + struct applesmc_device *smc = dev_get_drvdata(dev); + const char *key = smc->reg.index[to_index(devattr)]; int ret; s16 value; int temp; - ret = applesmc_read_s16(key, &value); + ret = applesmc_read_s16(smc, key, &value); if (ret) return ret; @@ -838,6 +1278,8 @@ static ssize_t applesmc_show_temperature(struct device *dev, static ssize_t applesmc_show_fan_speed(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; int ret; unsigned int speed = 0; char newkey[5]; @@ -846,11 +1288,21 @@ static ssize_t applesmc_show_fan_speed(struct device *dev, scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], to_index(attr)); - ret = applesmc_read_key(newkey, buffer, 2); + entry = applesmc_get_entry_by_key(smc, newkey); + if (IS_ERR(entry)) + return PTR_ERR(entry); + + if (!strcmp(entry->type, FLOAT_TYPE)) { + ret = applesmc_read_entry(smc, entry, (u8 *) &speed, 4); + speed = applesmc_float_to_u32(speed); + } else { + ret = applesmc_read_entry(smc, entry, buffer, 2); + speed = ((buffer[0] << 8 | buffer[1]) >> 2); + } + if (ret) return ret; - speed = ((buffer[0] << 8 | buffer[1]) >> 2); return sysfs_emit(sysfsbuf, "%u\n", speed); } @@ -858,6 +1310,8 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; int ret; unsigned long speed; char newkey[5]; @@ -869,9 +1323,18 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], to_index(attr)); - buffer[0] = (speed >> 6) & 0xff; - buffer[1] = (speed << 2) & 0xff; - ret = applesmc_write_key(newkey, buffer, 2); + entry = applesmc_get_entry_by_key(smc, newkey); + if (IS_ERR(entry)) + return PTR_ERR(entry); + + if (!strcmp(entry->type, FLOAT_TYPE)) { + speed = applesmc_u32_to_float(speed); + ret = applesmc_write_entry(smc, entry, (u8 *) &speed, 4); + } else { + buffer[0] = (speed >> 6) & 0xff; + buffer[1] = (speed << 2) & 0xff; + ret = applesmc_write_key(smc, newkey, buffer, 2); + } if (ret) return ret; @@ -882,15 +1345,30 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, static ssize_t applesmc_show_fan_manual(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u16 manual = 0; u8 buffer[2]; + char newkey[5]; + bool has_newkey = false; + + scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); + + ret = applesmc_has_key(smc, newkey, &has_newkey); + if (ret) + return ret; + + if (has_newkey) { + ret = applesmc_read_key(smc, newkey, buffer, 1); + manual = buffer[0]; + } else { + ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); + manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; + } - ret = applesmc_read_key(FANS_MANUAL, buffer, 2); if (ret) return ret; - manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; return sysfs_emit(sysfsbuf, "%d\n", manual); } @@ -898,29 +1376,42 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u8 buffer[2]; + char newkey[5]; + bool has_newkey = false; unsigned long input; u16 val; if (kstrtoul(sysfsbuf, 10, &input) < 0) return -EINVAL; - ret = applesmc_read_key(FANS_MANUAL, buffer, 2); + scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); + + ret = applesmc_has_key(smc, newkey, &has_newkey); if (ret) - goto out; + return ret; - val = (buffer[0] << 8 | buffer[1]); + if (has_newkey) { + buffer[0] = input & 1; + ret = applesmc_write_key(smc, newkey, buffer, 1); + } else { + ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); + val = (buffer[0] << 8 | buffer[1]); + if (ret) + goto out; - if (input) - val = val | (0x01 << to_index(attr)); - else - val = val & ~(0x01 << to_index(attr)); + if (input) + val = val | (0x01 << to_index(attr)); + else + val = val & ~(0x01 << to_index(attr)); - buffer[0] = (val >> 8) & 0xFF; - buffer[1] = val & 0xFF; + buffer[0] = (val >> 8) & 0xFF; + buffer[1] = val & 0xFF; - ret = applesmc_write_key(FANS_MANUAL, buffer, 2); + ret = applesmc_write_key(smc, FANS_MANUAL, buffer, 2); + } out: if (ret) @@ -932,13 +1423,14 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, static ssize_t applesmc_show_fan_position(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; char newkey[5]; u8 buffer[17]; scnprintf(newkey, sizeof(newkey), FAN_ID_FMT, to_index(attr)); - ret = applesmc_read_key(newkey, buffer, 16); + ret = applesmc_read_key(smc, newkey, buffer, 16); buffer[16] = 0; if (ret) @@ -950,43 +1442,79 @@ static ssize_t applesmc_show_fan_position(struct device *dev, static ssize_t applesmc_calibrate_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { - return sysfs_emit(sysfsbuf, "(%d,%d)\n", rest_x, rest_y); + struct applesmc_device *smc = dev_get_drvdata(dev); + + return sysfs_emit(sysfsbuf, "(%d,%d)\n", smc->rest_x, smc->rest_y); } static ssize_t applesmc_calibrate_store(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { - applesmc_calibrate(); + struct applesmc_device *smc = dev_get_drvdata(dev); + + applesmc_calibrate(smc); return count; } static void applesmc_backlight_set(struct work_struct *work) { - applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); + struct applesmc_device *smc = container_of(work, struct applesmc_device, backlight_work); + + applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); } -static DECLARE_WORK(backlight_work, &applesmc_backlight_set); static void applesmc_brightness_set(struct led_classdev *led_cdev, enum led_brightness value) { + struct applesmc_device *smc = dev_get_drvdata(led_cdev->dev); int ret; - backlight_state[0] = value; - ret = queue_work(applesmc_led_wq, &backlight_work); + smc->backlight_state[0] = value; + ret = queue_work(smc->backlight_wq, &smc->backlight_work); if (debug && (!ret)) dev_dbg(led_cdev->dev, "work was already on the queue.\n"); } +static ssize_t applesmc_BCLM_store(struct device *dev, + struct device_attribute *attr, char *sysfsbuf, size_t count) +{ + struct applesmc_device *smc = dev_get_drvdata(dev); + u8 val; + + if (kstrtou8(sysfsbuf, 10, &val) < 0) + return -EINVAL; + + if (val < 0 || val > 100) + return -EINVAL; + + if (applesmc_write_key(smc, "BCLM", &val, 1)) + return -ENODEV; + return count; +} + +static ssize_t applesmc_BCLM_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) +{ + struct applesmc_device *smc = dev_get_drvdata(dev); + u8 val; + + if (applesmc_read_key(smc, "BCLM", &val, 1)) + return -ENODEV; + + return sysfs_emit(sysfsbuf, "%d\n", val); +} + static ssize_t applesmc_key_count_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u8 buffer[4]; u32 count; - ret = applesmc_read_key(KEY_COUNT_KEY, buffer, 4); + ret = applesmc_read_key(smc, KEY_COUNT_KEY, buffer, 4); if (ret) return ret; @@ -998,13 +1526,14 @@ static ssize_t applesmc_key_count_show(struct device *dev, static ssize_t applesmc_key_at_index_read_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; int ret; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); - ret = applesmc_read_entry(entry, sysfsbuf, entry->len); + ret = applesmc_read_entry(smc, entry, sysfsbuf, entry->len); if (ret) return ret; @@ -1014,9 +1543,10 @@ static ssize_t applesmc_key_at_index_read_show(struct device *dev, static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1026,9 +1556,10 @@ static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, static ssize_t applesmc_key_at_index_type_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1038,9 +1569,10 @@ static ssize_t applesmc_key_at_index_type_show(struct device *dev, static ssize_t applesmc_key_at_index_name_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1050,28 +1582,25 @@ static ssize_t applesmc_key_at_index_name_show(struct device *dev, static ssize_t applesmc_key_at_index_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { - return sysfs_emit(sysfsbuf, "%d\n", key_at_index); + struct applesmc_device *smc = dev_get_drvdata(dev); + + return sysfs_emit(sysfsbuf, "%d\n", smc->key_at_index); } static ssize_t applesmc_key_at_index_store(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); unsigned long newkey; if (kstrtoul(sysfsbuf, 10, &newkey) < 0 - || newkey >= smcreg.key_count) + || newkey >= smc->reg.key_count) return -EINVAL; - key_at_index = newkey; + smc->key_at_index = newkey; return count; } -static struct led_classdev applesmc_backlight = { - .name = "smc::kbd_backlight", - .default_trigger = "nand-disk", - .brightness_set = applesmc_brightness_set, -}; - static struct applesmc_node_group info_group[] = { { "name", applesmc_name_show }, { "key_count", applesmc_key_count_show }, @@ -1111,19 +1640,25 @@ static struct applesmc_node_group temp_group[] = { { } }; +static struct applesmc_node_group BCLM_group[] = { + { "battery_charge_limit", applesmc_BCLM_show, applesmc_BCLM_store }, + { } +}; + /* Module stuff */ /* * applesmc_destroy_nodes - remove files and free associated memory */ -static void applesmc_destroy_nodes(struct applesmc_node_group *groups) +static void applesmc_destroy_nodes(struct applesmc_device *smc, + struct applesmc_node_group *groups) { struct applesmc_node_group *grp; struct applesmc_dev_attr *node; for (grp = groups; grp->nodes; grp++) { for (node = grp->nodes; node->sda.dev_attr.attr.name; node++) - sysfs_remove_file(&pdev->dev.kobj, + sysfs_remove_file(&smc->dev->dev.kobj, &node->sda.dev_attr.attr); kfree(grp->nodes); grp->nodes = NULL; @@ -1133,7 +1668,8 @@ static void applesmc_destroy_nodes(struct applesmc_node_group *groups) /* * applesmc_create_nodes - create a two-dimensional group of sysfs files */ -static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) +static int applesmc_create_nodes(struct applesmc_device *smc, + struct applesmc_node_group *groups, int num) { struct applesmc_node_group *grp; struct applesmc_dev_attr *node; @@ -1157,7 +1693,7 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) sysfs_attr_init(attr); attr->name = node->name; attr->mode = 0444 | (grp->store ? 0200 : 0); - ret = sysfs_create_file(&pdev->dev.kobj, attr); + ret = sysfs_create_file(&smc->dev->dev.kobj, attr); if (ret) { attr->name = NULL; goto out; @@ -1167,57 +1703,56 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) return 0; out: - applesmc_destroy_nodes(groups); + applesmc_destroy_nodes(smc, groups); return ret; } /* Create accelerometer resources */ -static int applesmc_create_accelerometer(void) +static int applesmc_create_accelerometer(struct applesmc_device *smc) { int ret; - - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return 0; - ret = applesmc_create_nodes(accelerometer_group, 1); + ret = applesmc_create_nodes(smc, accelerometer_group, 1); if (ret) goto out; - applesmc_idev = input_allocate_device(); - if (!applesmc_idev) { + smc->idev = input_allocate_device(); + if (!smc->idev) { ret = -ENOMEM; goto out_sysfs; } /* initial calibrate for the input device */ - applesmc_calibrate(); + applesmc_calibrate(smc); /* initialize the input device */ - applesmc_idev->name = "applesmc"; - applesmc_idev->id.bustype = BUS_HOST; - applesmc_idev->dev.parent = &pdev->dev; - input_set_abs_params(applesmc_idev, ABS_X, + smc->idev->name = "applesmc"; + smc->idev->id.bustype = BUS_HOST; + smc->idev->dev.parent = &smc->dev->dev; + input_set_abs_params(smc->idev, ABS_X, -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); - input_set_abs_params(applesmc_idev, ABS_Y, + input_set_abs_params(smc->idev, ABS_Y, -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); - ret = input_setup_polling(applesmc_idev, applesmc_idev_poll); + ret = input_setup_polling(smc->idev, applesmc_idev_poll); if (ret) goto out_idev; - input_set_poll_interval(applesmc_idev, APPLESMC_POLL_INTERVAL); + input_set_poll_interval(smc->idev, APPLESMC_POLL_INTERVAL); - ret = input_register_device(applesmc_idev); + ret = input_register_device(smc->idev); if (ret) goto out_idev; return 0; out_idev: - input_free_device(applesmc_idev); + input_free_device(smc->idev); out_sysfs: - applesmc_destroy_nodes(accelerometer_group); + applesmc_destroy_nodes(smc, accelerometer_group); out: pr_warn("driver init failed (ret=%d)!\n", ret); @@ -1225,44 +1760,55 @@ static int applesmc_create_accelerometer(void) } /* Release all resources used by the accelerometer */ -static void applesmc_release_accelerometer(void) +static void applesmc_release_accelerometer(struct applesmc_device *smc) { - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return; - input_unregister_device(applesmc_idev); - applesmc_destroy_nodes(accelerometer_group); + input_unregister_device(smc->idev); + applesmc_destroy_nodes(smc, accelerometer_group); } -static int applesmc_create_light_sensor(void) +static int applesmc_create_light_sensor(struct applesmc_device *smc) { - if (!smcreg.num_light_sensors) + if (!smc->reg.num_light_sensors) return 0; - return applesmc_create_nodes(light_sensor_group, 1); + return applesmc_create_nodes(smc, light_sensor_group, 1); } -static void applesmc_release_light_sensor(void) +static void applesmc_release_light_sensor(struct applesmc_device *smc) { - if (!smcreg.num_light_sensors) + if (!smc->reg.num_light_sensors) return; - applesmc_destroy_nodes(light_sensor_group); + applesmc_destroy_nodes(smc, light_sensor_group); } -static int applesmc_create_key_backlight(void) +static int applesmc_create_key_backlight(struct applesmc_device *smc) { - if (!smcreg.has_key_backlight) + int ret; + + if (!smc->reg.has_key_backlight) return 0; - applesmc_led_wq = create_singlethread_workqueue("applesmc-led"); - if (!applesmc_led_wq) + smc->backlight_wq = create_singlethread_workqueue("applesmc-led"); + if (!smc->backlight_wq) return -ENOMEM; - return led_classdev_register(&pdev->dev, &applesmc_backlight); + + INIT_WORK(&smc->backlight_work, applesmc_backlight_set); + smc->backlight_dev.name = "smc::kbd_backlight"; + smc->backlight_dev.default_trigger = "nand-disk"; + smc->backlight_dev.brightness_set = applesmc_brightness_set; + ret = led_classdev_register(&smc->dev->dev, &smc->backlight_dev); + if (ret) + destroy_workqueue(smc->backlight_wq); + + return ret; } -static void applesmc_release_key_backlight(void) +static void applesmc_release_key_backlight(struct applesmc_device *smc) { - if (!smcreg.has_key_backlight) + if (!smc->reg.has_key_backlight) return; - led_classdev_unregister(&applesmc_backlight); - destroy_workqueue(applesmc_led_wq); + led_classdev_unregister(&smc->backlight_dev); + destroy_workqueue(smc->backlight_wq); } static int applesmc_dmi_match(const struct dmi_system_id *id) @@ -1291,6 +1837,10 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), DMI_MATCH(DMI_PRODUCT_NAME, "Macmini") }, }, + { applesmc_dmi_match, "Apple iMacPro", { + DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), + DMI_MATCH(DMI_PRODUCT_NAME, "iMacPro") }, + }, { applesmc_dmi_match, "Apple MacPro", { DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), DMI_MATCH(DMI_PRODUCT_NAME, "MacPro") }, @@ -1306,90 +1856,91 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { { .ident = NULL } }; -static int __init applesmc_init(void) +static int applesmc_create_modules(struct applesmc_device *smc) { int ret; - if (!dmi_check_system(applesmc_whitelist)) { - pr_warn("supported laptop not found!\n"); - ret = -ENODEV; - goto out; - } - - if (!request_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS, - "applesmc")) { - ret = -ENXIO; - goto out; - } - - ret = platform_driver_register(&applesmc_driver); - if (ret) - goto out_region; - - pdev = platform_device_register_simple("applesmc", APPLESMC_DATA_PORT, - NULL, 0); - if (IS_ERR(pdev)) { - ret = PTR_ERR(pdev); - goto out_driver; - } - - /* create register cache */ - ret = applesmc_init_smcreg(); + ret = applesmc_create_nodes(smc, info_group, 1); if (ret) - goto out_device; - - ret = applesmc_create_nodes(info_group, 1); + goto out; + ret = applesmc_create_nodes(smc, BCLM_group, 1); if (ret) - goto out_smcreg; + goto out_info; - ret = applesmc_create_nodes(fan_group, smcreg.fan_count); + ret = applesmc_create_nodes(smc, fan_group, smc->reg.fan_count); if (ret) - goto out_info; + goto out_bclm; - ret = applesmc_create_nodes(temp_group, smcreg.index_count); + ret = applesmc_create_nodes(smc, temp_group, smc->reg.index_count); if (ret) goto out_fans; - ret = applesmc_create_accelerometer(); + ret = applesmc_create_accelerometer(smc); if (ret) goto out_temperature; - ret = applesmc_create_light_sensor(); + ret = applesmc_create_light_sensor(smc); if (ret) goto out_accelerometer; - ret = applesmc_create_key_backlight(); + ret = applesmc_create_key_backlight(smc); if (ret) goto out_light_sysfs; - hwmon_dev = hwmon_device_register(&pdev->dev); - if (IS_ERR(hwmon_dev)) { - ret = PTR_ERR(hwmon_dev); + smc->hwmon_dev = hwmon_device_register(&smc->dev->dev); + if (IS_ERR(smc->hwmon_dev)) { + ret = PTR_ERR(smc->hwmon_dev); goto out_light_ledclass; } return 0; out_light_ledclass: - applesmc_release_key_backlight(); + applesmc_release_key_backlight(smc); out_light_sysfs: - applesmc_release_light_sensor(); + applesmc_release_light_sensor(smc); out_accelerometer: - applesmc_release_accelerometer(); + applesmc_release_accelerometer(smc); out_temperature: - applesmc_destroy_nodes(temp_group); + applesmc_destroy_nodes(smc, temp_group); out_fans: - applesmc_destroy_nodes(fan_group); + applesmc_destroy_nodes(smc, fan_group); +out_bclm: + applesmc_destroy_nodes(smc, BCLM_group); out_info: - applesmc_destroy_nodes(info_group); -out_smcreg: - applesmc_destroy_smcreg(); -out_device: - platform_device_unregister(pdev); -out_driver: - platform_driver_unregister(&applesmc_driver); -out_region: - release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); + applesmc_destroy_nodes(smc, info_group); +out: + return ret; +} + +static void applesmc_destroy_modules(struct applesmc_device *smc) +{ + hwmon_device_unregister(smc->hwmon_dev); + applesmc_release_key_backlight(smc); + applesmc_release_light_sensor(smc); + applesmc_release_accelerometer(smc); + applesmc_destroy_nodes(smc, temp_group); + applesmc_destroy_nodes(smc, fan_group); + applesmc_destroy_nodes(smc, BCLM_group); + applesmc_destroy_nodes(smc, info_group); +} + +static int __init applesmc_init(void) +{ + int ret; + + if (!dmi_check_system(applesmc_whitelist)) { + pr_warn("supported laptop not found!\n"); + ret = -ENODEV; + goto out; + } + + ret = acpi_bus_register_driver(&applesmc_driver); + if (ret) + goto out; + + return 0; + out: pr_warn("driver init failed (ret=%d)!\n", ret); return ret; @@ -1397,23 +1948,14 @@ static int __init applesmc_init(void) static void __exit applesmc_exit(void) { - hwmon_device_unregister(hwmon_dev); - applesmc_release_key_backlight(); - applesmc_release_light_sensor(); - applesmc_release_accelerometer(); - applesmc_destroy_nodes(temp_group); - applesmc_destroy_nodes(fan_group); - applesmc_destroy_nodes(info_group); - applesmc_destroy_smcreg(); - platform_device_unregister(pdev); - platform_driver_unregister(&applesmc_driver); - release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); + acpi_bus_unregister_driver(&applesmc_driver); } module_init(applesmc_init); module_exit(applesmc_exit); MODULE_AUTHOR("Nicolas Boichat"); +MODULE_AUTHOR("Paul Pawlowski"); MODULE_DESCRIPTION("Apple SMC"); MODULE_LICENSE("GPL v2"); MODULE_DEVICE_TABLE(dmi, applesmc_whitelist); diff --git a/drivers/nvme/host/apple.c b/drivers/nvme/host/apple.c index 8971aca41e63..6cf0e3cc9682 100644 --- a/drivers/nvme/host/apple.c +++ b/drivers/nvme/host/apple.c @@ -221,7 +221,7 @@ static unsigned int apple_nvme_queue_depth(struct apple_nvme_queue *q) return APPLE_ANS_MAX_QUEUE_DEPTH; } -static void apple_nvme_rtkit_crashed(void *cookie) +static void apple_nvme_rtkit_crashed(void *cookie, const void *crashlog, size_t crashlog_size) { struct apple_nvme *anv = cookie; diff --git a/drivers/pci/vgaarb.c b/drivers/pci/vgaarb.c index 78748e8d2dba..2b2b558cebe6 100644 --- a/drivers/pci/vgaarb.c +++ b/drivers/pci/vgaarb.c @@ -143,6 +143,7 @@ void vga_set_default_device(struct pci_dev *pdev) pci_dev_put(vga_default); vga_default = pci_dev_get(pdev); } +EXPORT_SYMBOL_GPL(vga_set_default_device); /** * vga_remove_vgacon - deactivate VGA console diff --git a/drivers/platform/x86/apple-gmux.c b/drivers/platform/x86/apple-gmux.c index 1417e230edbd..e69785af8e1d 100644 --- a/drivers/platform/x86/apple-gmux.c +++ b/drivers/platform/x86/apple-gmux.c @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -107,6 +108,10 @@ struct apple_gmux_config { # define MMIO_GMUX_MAX_BRIGHTNESS 0xffff +static bool force_igd; +module_param(force_igd, bool, 0); +MODULE_PARM_DESC(force_idg, "Switch gpu to igd on module load. Make sure that you have apple-set-os set up and the iGPU is in `lspci -s 00:02.0`. (default: false) (bool)"); + static u8 gmux_pio_read8(struct apple_gmux_data *gmux_data, int port) { return inb(gmux_data->iostart + port); @@ -945,6 +950,19 @@ static int gmux_probe(struct pnp_dev *pnp, const struct pnp_device_id *id) gmux_enable_interrupts(gmux_data); gmux_read_switch_state(gmux_data); + if (force_igd) { + struct pci_dev *pdev; + + pdev = pci_get_domain_bus_and_slot(0, 0, PCI_DEVFN(2, 0)); + if (pdev) { + pr_info("Switching to IGD"); + gmux_switchto(VGA_SWITCHEROO_IGD); + vga_set_default_device(pdev); + } else { + pr_err("force_idg is true, but couldn't find iGPU at 00:02.0! Is apple-set-os working?"); + } + } + /* * Retina MacBook Pros cannot switch the panel's AUX separately * and need eDP pre-calibration. They are distinguishable from diff --git a/drivers/soc/apple/Kconfig b/drivers/soc/apple/Kconfig index 6388cbe1e56b..50f092732796 100644 --- a/drivers/soc/apple/Kconfig +++ b/drivers/soc/apple/Kconfig @@ -4,6 +4,16 @@ if ARCH_APPLE || COMPILE_TEST menu "Apple SoC drivers" +config APPLE_DOCKCHANNEL + tristate "Apple DockChannel FIFO" + depends on ARCH_APPLE || COMPILE_TEST + default ARCH_APPLE + help + DockChannel is a simple FIFO used on Apple SoCs for debug and inter-processor + communications. + + Say 'y' here if you have an Apple SoC. + config APPLE_MAILBOX tristate "Apple SoC mailboxes" depends on PM @@ -30,6 +40,20 @@ config APPLE_RTKIT Say 'y' here if you have an Apple SoC. +config APPLE_RTKIT_HELPER + tristate "Apple Generic RTKit helper co-processor" + depends on APPLE_RTKIT + depends on ARCH_APPLE || COMPILE_TEST + default ARCH_APPLE + help + Apple SoCs such as the M1 come with various co-processors running + their proprietary RTKit operating system. This option enables support + for a generic co-processor that does not implement any additional + in-band communications. It can be used for testing purposes, or for + coprocessors such as MTP that communicate over a different interface. + + Say 'y' here if you have an Apple SoC. + config APPLE_SART tristate "Apple SART DMA address filter" depends on ARCH_APPLE || COMPILE_TEST diff --git a/drivers/soc/apple/Makefile b/drivers/soc/apple/Makefile index 4d9ab8f3037b..5e526a9edcf2 100644 --- a/drivers/soc/apple/Makefile +++ b/drivers/soc/apple/Makefile @@ -1,10 +1,16 @@ # SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_APPLE_DOCKCHANNEL) += apple-dockchannel.o +apple-dockchannel-y = dockchannel.o + obj-$(CONFIG_APPLE_MAILBOX) += apple-mailbox.o apple-mailbox-y = mailbox.o obj-$(CONFIG_APPLE_RTKIT) += apple-rtkit.o apple-rtkit-y = rtkit.o rtkit-crashlog.o +obj-$(CONFIG_APPLE_RTKIT_HELPER) += apple-rtkit-helper.o +apple-rtkit-helper-y = rtkit-helper.o + obj-$(CONFIG_APPLE_SART) += apple-sart.o apple-sart-y = sart.o diff --git a/drivers/soc/apple/dockchannel.c b/drivers/soc/apple/dockchannel.c new file mode 100644 index 000000000000..3a0d7964007c --- /dev/null +++ b/drivers/soc/apple/dockchannel.c @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: GPL-2.0-only OR MIT +/* + * Apple DockChannel FIFO driver + * Copyright The Asahi Linux Contributors + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DOCKCHANNEL_MAX_IRQ 32 + +#define DOCKCHANNEL_TX_TIMEOUT_MS 1000 +#define DOCKCHANNEL_RX_TIMEOUT_MS 1000 + +#define IRQ_MASK 0x0 +#define IRQ_FLAG 0x4 + +#define IRQ_TX BIT(0) +#define IRQ_RX BIT(1) + +#define CONFIG_TX_THRESH 0x0 +#define CONFIG_RX_THRESH 0x4 + +#define DATA_TX8 0x4 +#define DATA_TX16 0x8 +#define DATA_TX24 0xc +#define DATA_TX32 0x10 +#define DATA_TX_FREE 0x14 +#define DATA_RX8 0x1c +#define DATA_RX16 0x20 +#define DATA_RX24 0x24 +#define DATA_RX32 0x28 +#define DATA_RX_COUNT 0x2c + +struct dockchannel { + struct device *dev; + int tx_irq; + int rx_irq; + + void __iomem *config_base; + void __iomem *data_base; + + u32 fifo_size; + bool awaiting; + struct completion tx_comp; + struct completion rx_comp; + + void *cookie; + void (*data_available)(void *cookie, size_t avail); +}; + +struct dockchannel_common { + struct device *dev; + struct irq_domain *domain; + int irq; + + void __iomem *irq_base; +}; + +/* Dockchannel FIFO functions */ + +static irqreturn_t dockchannel_tx_irq(int irq, void *data) +{ + struct dockchannel *dockchannel = data; + + disable_irq_nosync(irq); + complete(&dockchannel->tx_comp); + + return IRQ_HANDLED; +} + +static irqreturn_t dockchannel_rx_irq(int irq, void *data) +{ + struct dockchannel *dockchannel = data; + + disable_irq_nosync(irq); + + if (dockchannel->awaiting) { + return IRQ_WAKE_THREAD; + } else { + complete(&dockchannel->rx_comp); + return IRQ_HANDLED; + } +} + +static irqreturn_t dockchannel_rx_irq_thread(int irq, void *data) +{ + struct dockchannel *dockchannel = data; + size_t avail = readl_relaxed(dockchannel->data_base + DATA_RX_COUNT); + + dockchannel->awaiting = false; + dockchannel->data_available(dockchannel->cookie, avail); + + return IRQ_HANDLED; +} + +int dockchannel_send(struct dockchannel *dockchannel, const void *buf, size_t count) +{ + size_t left = count; + const u8 *p = buf; + + while (left > 0) { + size_t avail = readl_relaxed(dockchannel->data_base + DATA_TX_FREE); + size_t block = min(left, avail); + + if (avail == 0) { + size_t threshold = min((size_t)(dockchannel->fifo_size / 2), left); + + writel_relaxed(threshold, dockchannel->config_base + CONFIG_TX_THRESH); + reinit_completion(&dockchannel->tx_comp); + enable_irq(dockchannel->tx_irq); + + if (!wait_for_completion_timeout(&dockchannel->tx_comp, + msecs_to_jiffies(DOCKCHANNEL_TX_TIMEOUT_MS))) { + disable_irq(dockchannel->tx_irq); + return -ETIMEDOUT; + } + + continue; + } + + while (block >= 4) { + writel_relaxed(get_unaligned_le32(p), dockchannel->data_base + DATA_TX32); + p += 4; + left -= 4; + block -= 4; + } + while (block > 0) { + writeb_relaxed(*p++, dockchannel->data_base + DATA_TX8); + left--; + block--; + } + } + + return count; +} +EXPORT_SYMBOL(dockchannel_send); + +int dockchannel_recv(struct dockchannel *dockchannel, void *buf, size_t count) +{ + size_t left = count; + u8 *p = buf; + + while (left > 0) { + size_t avail = readl_relaxed(dockchannel->data_base + DATA_RX_COUNT); + size_t block = min(left, avail); + + if (avail == 0) { + size_t threshold = min((size_t)(dockchannel->fifo_size / 2), left); + + writel_relaxed(threshold, dockchannel->config_base + CONFIG_RX_THRESH); + reinit_completion(&dockchannel->rx_comp); + enable_irq(dockchannel->rx_irq); + + if (!wait_for_completion_timeout(&dockchannel->rx_comp, + msecs_to_jiffies(DOCKCHANNEL_RX_TIMEOUT_MS))) { + disable_irq(dockchannel->rx_irq); + return -ETIMEDOUT; + } + + continue; + } + + while (block >= 4) { + put_unaligned_le32(readl_relaxed(dockchannel->data_base + DATA_RX32), p); + p += 4; + left -= 4; + block -= 4; + } + while (block > 0) { + *p++ = readl_relaxed(dockchannel->data_base + DATA_RX8) >> 8; + left--; + block--; + } + } + + return count; +} +EXPORT_SYMBOL(dockchannel_recv); + +int dockchannel_await(struct dockchannel *dockchannel, + void (*callback)(void *cookie, size_t avail), + void *cookie, size_t count) +{ + size_t threshold = min((size_t)dockchannel->fifo_size, count); + + if (!count) { + dockchannel->awaiting = false; + disable_irq(dockchannel->rx_irq); + return 0; + } + + dockchannel->data_available = callback; + dockchannel->cookie = cookie; + dockchannel->awaiting = true; + writel_relaxed(threshold, dockchannel->config_base + CONFIG_RX_THRESH); + enable_irq(dockchannel->rx_irq); + + return threshold; +} +EXPORT_SYMBOL(dockchannel_await); + +struct dockchannel *dockchannel_init(struct platform_device *pdev) +{ + struct device *dev = &pdev->dev; + struct dockchannel *dockchannel; + int ret; + + dockchannel = devm_kzalloc(dev, sizeof(*dockchannel), GFP_KERNEL); + if (!dockchannel) + return ERR_PTR(-ENOMEM); + + dockchannel->dev = dev; + dockchannel->config_base = devm_platform_ioremap_resource_byname(pdev, "config"); + if (IS_ERR(dockchannel->config_base)) + return (__force void *)dockchannel->config_base; + + dockchannel->data_base = devm_platform_ioremap_resource_byname(pdev, "data"); + if (IS_ERR(dockchannel->data_base)) + return (__force void *)dockchannel->data_base; + + ret = of_property_read_u32(dev->of_node, "apple,fifo-size", &dockchannel->fifo_size); + if (ret) + return ERR_PTR(dev_err_probe(dev, ret, "Missing apple,fifo-size property")); + + init_completion(&dockchannel->tx_comp); + init_completion(&dockchannel->rx_comp); + + dockchannel->tx_irq = platform_get_irq_byname(pdev, "tx"); + if (dockchannel->tx_irq <= 0) { + return ERR_PTR(dev_err_probe(dev, dockchannel->tx_irq, + "Failed to get TX IRQ")); + } + + dockchannel->rx_irq = platform_get_irq_byname(pdev, "rx"); + if (dockchannel->rx_irq <= 0) { + return ERR_PTR(dev_err_probe(dev, dockchannel->rx_irq, + "Failed to get RX IRQ")); + } + + ret = devm_request_irq(dev, dockchannel->tx_irq, dockchannel_tx_irq, IRQF_NO_AUTOEN, + "apple-dockchannel-tx", dockchannel); + if (ret) + return ERR_PTR(dev_err_probe(dev, ret, "Failed to request TX IRQ")); + + ret = devm_request_threaded_irq(dev, dockchannel->rx_irq, dockchannel_rx_irq, + dockchannel_rx_irq_thread, IRQF_NO_AUTOEN, + "apple-dockchannel-rx", dockchannel); + if (ret) + return ERR_PTR(dev_err_probe(dev, ret, "Failed to request RX IRQ")); + + return dockchannel; +} +EXPORT_SYMBOL(dockchannel_init); + + +/* Dockchannel IRQchip */ + +static void dockchannel_irq(struct irq_desc *desc) +{ + unsigned int irq = irq_desc_get_irq(desc); + struct irq_chip *chip = irq_desc_get_chip(desc); + struct dockchannel_common *dcc = irq_get_handler_data(irq); + unsigned long flags = readl_relaxed(dcc->irq_base + IRQ_FLAG); + int bit; + + chained_irq_enter(chip, desc); + + for_each_set_bit(bit, &flags, DOCKCHANNEL_MAX_IRQ) + generic_handle_domain_irq(dcc->domain, bit); + + chained_irq_exit(chip, desc); +} + +static void dockchannel_irq_ack(struct irq_data *data) +{ + struct dockchannel_common *dcc = irq_data_get_irq_chip_data(data); + unsigned int hwirq = data->hwirq; + + writel_relaxed(BIT(hwirq), dcc->irq_base + IRQ_FLAG); +} + +static void dockchannel_irq_mask(struct irq_data *data) +{ + struct dockchannel_common *dcc = irq_data_get_irq_chip_data(data); + unsigned int hwirq = data->hwirq; + u32 val = readl_relaxed(dcc->irq_base + IRQ_MASK); + + writel_relaxed(val & ~BIT(hwirq), dcc->irq_base + IRQ_MASK); +} + +static void dockchannel_irq_unmask(struct irq_data *data) +{ + struct dockchannel_common *dcc = irq_data_get_irq_chip_data(data); + unsigned int hwirq = data->hwirq; + u32 val = readl_relaxed(dcc->irq_base + IRQ_MASK); + + writel_relaxed(val | BIT(hwirq), dcc->irq_base + IRQ_MASK); +} + +static const struct irq_chip dockchannel_irqchip = { + .name = "dockchannel-irqc", + .irq_ack = dockchannel_irq_ack, + .irq_mask = dockchannel_irq_mask, + .irq_unmask = dockchannel_irq_unmask, +}; + +static int dockchannel_irq_domain_map(struct irq_domain *d, unsigned int virq, + irq_hw_number_t hw) +{ + irq_set_chip_data(virq, d->host_data); + irq_set_chip_and_handler(virq, &dockchannel_irqchip, handle_level_irq); + + return 0; +} + +static const struct irq_domain_ops dockchannel_irq_domain_ops = { + .xlate = irq_domain_xlate_twocell, + .map = dockchannel_irq_domain_map, +}; + +static int dockchannel_probe(struct platform_device *pdev) +{ + struct device *dev = &pdev->dev; + struct dockchannel_common *dcc; + struct device_node *child; + + dcc = devm_kzalloc(dev, sizeof(*dcc), GFP_KERNEL); + if (!dcc) + return -ENOMEM; + + dcc->dev = dev; + platform_set_drvdata(pdev, dcc); + + dcc->irq_base = devm_platform_ioremap_resource_byname(pdev, "irq"); + if (IS_ERR(dcc->irq_base)) + return PTR_ERR(dcc->irq_base); + + writel_relaxed(0, dcc->irq_base + IRQ_MASK); + writel_relaxed(~0, dcc->irq_base + IRQ_FLAG); + + dcc->domain = irq_domain_add_linear(dev->of_node, DOCKCHANNEL_MAX_IRQ, + &dockchannel_irq_domain_ops, dcc); + if (!dcc->domain) + return -ENOMEM; + + dcc->irq = platform_get_irq(pdev, 0); + if (dcc->irq <= 0) + return dev_err_probe(dev, dcc->irq, "Failed to get IRQ"); + + irq_set_handler_data(dcc->irq, dcc); + irq_set_chained_handler(dcc->irq, dockchannel_irq); + + for_each_child_of_node(dev->of_node, child) + of_platform_device_create(child, NULL, dev); + + return 0; +} + +static void dockchannel_remove(struct platform_device *pdev) +{ + struct dockchannel_common *dcc = platform_get_drvdata(pdev); + int hwirq; + + device_for_each_child(&pdev->dev, NULL, of_platform_device_destroy); + + irq_set_chained_handler_and_data(dcc->irq, NULL, NULL); + + for (hwirq = 0; hwirq < DOCKCHANNEL_MAX_IRQ; hwirq++) + irq_dispose_mapping(irq_find_mapping(dcc->domain, hwirq)); + + irq_domain_remove(dcc->domain); + + writel_relaxed(0, dcc->irq_base + IRQ_MASK); + writel_relaxed(~0, dcc->irq_base + IRQ_FLAG); +} + +static const struct of_device_id dockchannel_of_match[] = { + { .compatible = "apple,dockchannel" }, + {}, +}; +MODULE_DEVICE_TABLE(of, dockchannel_of_match); + +static struct platform_driver dockchannel_driver = { + .driver = { + .name = "dockchannel", + .of_match_table = dockchannel_of_match, + }, + .probe = dockchannel_probe, + .remove = dockchannel_remove, +}; +module_platform_driver(dockchannel_driver); + +MODULE_AUTHOR("Hector Martin "); +MODULE_LICENSE("Dual MIT/GPL"); +MODULE_DESCRIPTION("Apple DockChannel driver"); diff --git a/drivers/soc/apple/rtkit-helper.c b/drivers/soc/apple/rtkit-helper.c new file mode 100644 index 000000000000..080d083ed9bd --- /dev/null +++ b/drivers/soc/apple/rtkit-helper.c @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: GPL-2.0-only OR MIT +/* + * Apple Generic RTKit helper coprocessor + * Copyright The Asahi Linux Contributors + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define APPLE_ASC_CPU_CONTROL 0x44 +#define APPLE_ASC_CPU_CONTROL_RUN BIT(4) + +struct apple_rtkit_helper { + struct device *dev; + struct apple_rtkit *rtk; + + void __iomem *asc_base; + + struct resource *sram; + void __iomem *sram_base; +}; + +static int apple_rtkit_helper_shmem_setup(void *cookie, struct apple_rtkit_shmem *bfr) +{ + struct apple_rtkit_helper *helper = cookie; + struct resource res = { + .start = bfr->iova, + .end = bfr->iova + bfr->size - 1, + .name = "rtkit_map", + }; + + if (!bfr->iova) { + bfr->buffer = dma_alloc_coherent(helper->dev, bfr->size, + &bfr->iova, GFP_KERNEL); + if (!bfr->buffer) + return -ENOMEM; + return 0; + } + + if (!helper->sram) { + dev_err(helper->dev, + "RTKit buffer request with no SRAM region: %pR", &res); + return -EFAULT; + } + + res.flags = helper->sram->flags; + + if (res.end < res.start || !resource_contains(helper->sram, &res)) { + dev_err(helper->dev, + "RTKit buffer request outside SRAM region: %pR", &res); + return -EFAULT; + } + + bfr->iomem = helper->sram_base + (res.start - helper->sram->start); + bfr->is_mapped = true; + + return 0; +} + +static void apple_rtkit_helper_shmem_destroy(void *cookie, struct apple_rtkit_shmem *bfr) +{ + // no-op +} + +static const struct apple_rtkit_ops apple_rtkit_helper_ops = { + .shmem_setup = apple_rtkit_helper_shmem_setup, + .shmem_destroy = apple_rtkit_helper_shmem_destroy, +}; + +static int apple_rtkit_helper_probe(struct platform_device *pdev) +{ + struct device *dev = &pdev->dev; + struct apple_rtkit_helper *helper; + int ret; + + /* 44 bits for addresses in standard RTKit requests */ + ret = dma_set_mask_and_coherent(&pdev->dev, DMA_BIT_MASK(44)); + if (ret) + return ret; + + helper = devm_kzalloc(dev, sizeof(*helper), GFP_KERNEL); + if (!helper) + return -ENOMEM; + + helper->dev = dev; + platform_set_drvdata(pdev, helper); + + helper->asc_base = devm_platform_ioremap_resource_byname(pdev, "asc"); + if (IS_ERR(helper->asc_base)) + return PTR_ERR(helper->asc_base); + + helper->sram = platform_get_resource_byname(pdev, IORESOURCE_MEM, "sram"); + if (helper->sram) { + helper->sram_base = devm_ioremap_resource(dev, helper->sram); + if (IS_ERR(helper->sram_base)) + return dev_err_probe(dev, PTR_ERR(helper->sram_base), + "Failed to map SRAM region"); + } + + helper->rtk = + devm_apple_rtkit_init(dev, helper, NULL, 0, &apple_rtkit_helper_ops); + if (IS_ERR(helper->rtk)) + return dev_err_probe(dev, PTR_ERR(helper->rtk), + "Failed to intialize RTKit"); + + writel_relaxed(APPLE_ASC_CPU_CONTROL_RUN, + helper->asc_base + APPLE_ASC_CPU_CONTROL); + + /* Works for both wake and boot */ + ret = apple_rtkit_wake(helper->rtk); + if (ret != 0) + return dev_err_probe(dev, ret, "Failed to wake up coprocessor"); + + return 0; +} + +static void apple_rtkit_helper_remove(struct platform_device *pdev) +{ + struct apple_rtkit_helper *helper = platform_get_drvdata(pdev); + + if (apple_rtkit_is_running(helper->rtk)) + apple_rtkit_quiesce(helper->rtk); + + writel_relaxed(0, helper->asc_base + APPLE_ASC_CPU_CONTROL); +} + +static const struct of_device_id apple_rtkit_helper_of_match[] = { + { .compatible = "apple,rtk-helper-asc4" }, + {}, +}; +MODULE_DEVICE_TABLE(of, apple_rtkit_helper_of_match); + +static struct platform_driver apple_rtkit_helper_driver = { + .driver = { + .name = "rtkit-helper", + .of_match_table = apple_rtkit_helper_of_match, + }, + .probe = apple_rtkit_helper_probe, + .remove = apple_rtkit_helper_remove, +}; +module_platform_driver(apple_rtkit_helper_driver); + +MODULE_AUTHOR("Hector Martin "); +MODULE_LICENSE("Dual MIT/GPL"); +MODULE_DESCRIPTION("Apple RTKit helper driver"); diff --git a/drivers/soc/apple/rtkit.c b/drivers/soc/apple/rtkit.c index e6d940292c9f..4b0783091a92 100644 --- a/drivers/soc/apple/rtkit.c +++ b/drivers/soc/apple/rtkit.c @@ -368,7 +368,7 @@ static void apple_rtkit_crashlog_rx(struct apple_rtkit *rtk, u64 msg) rtk->crashed = true; if (rtk->ops->crashed) - rtk->ops->crashed(rtk->cookie); + rtk->ops->crashed(rtk->cookie, bfr, rtk->crashlog_buffer.size); } static void apple_rtkit_ioreport_rx(struct apple_rtkit *rtk, u64 msg) diff --git a/drivers/staging/Kconfig b/drivers/staging/Kconfig index 075e775d3868..e1cc0d60eeb6 100644 --- a/drivers/staging/Kconfig +++ b/drivers/staging/Kconfig @@ -50,4 +50,6 @@ source "drivers/staging/vme_user/Kconfig" source "drivers/staging/gpib/Kconfig" +source "drivers/staging/apple-bce/Kconfig" + endif # STAGING diff --git a/drivers/staging/Makefile b/drivers/staging/Makefile index e681e403509c..4045c588b3b4 100644 --- a/drivers/staging/Makefile +++ b/drivers/staging/Makefile @@ -14,3 +14,4 @@ obj-$(CONFIG_GREYBUS) += greybus/ obj-$(CONFIG_BCM2835_VCHIQ) += vc04_services/ obj-$(CONFIG_XIL_AXIS_FIFO) += axis-fifo/ obj-$(CONFIG_GPIB) += gpib/ +obj-$(CONFIG_APPLE_BCE) += apple-bce/ diff --git a/drivers/staging/apple-bce/Kconfig b/drivers/staging/apple-bce/Kconfig new file mode 100644 index 000000000000..fe92bc441e89 --- /dev/null +++ b/drivers/staging/apple-bce/Kconfig @@ -0,0 +1,18 @@ +config APPLE_BCE + tristate "Apple BCE driver (VHCI and Audio support)" + default m + depends on X86 + select SOUND + select SND + select SND_PCM + select SND_JACK + help + VHCI and audio support on Apple MacBooks with the T2 Chip. + This driver is divided in three components: + - BCE (Buffer Copy Engine): which establishes a basic communication + channel with the T2 chip. This component is required by the other two: + - VHCI (Virtual Host Controller Interface): Access to keyboard, mouse + and other system devices depend on this virtual USB host controller + - Audio: a driver for the T2 audio interface. + + If "M" is selected, the module will be called apple-bce.' diff --git a/drivers/staging/apple-bce/Makefile b/drivers/staging/apple-bce/Makefile new file mode 100644 index 000000000000..8cfbd3f64af6 --- /dev/null +++ b/drivers/staging/apple-bce/Makefile @@ -0,0 +1,28 @@ +modname := apple-bce +obj-$(CONFIG_APPLE_BCE) += $(modname).o + +apple-bce-objs := apple_bce.o mailbox.o queue.o queue_dma.o vhci/vhci.o vhci/queue.o vhci/transfer.o audio/audio.o audio/protocol.o audio/protocol_bce.o audio/pcm.o + +MY_CFLAGS += -DWITHOUT_NVME_PATCH +#MY_CFLAGS += -g -DDEBUG +ccflags-y += ${MY_CFLAGS} +CC += ${MY_CFLAGS} + +KVERSION := $(KERNELRELEASE) +ifeq ($(origin KERNELRELEASE), undefined) +KVERSION := $(shell uname -r) +endif + +KDIR := /lib/modules/$(KVERSION)/build +PWD := $(shell pwd) + +.PHONY: all + +all: + $(MAKE) -C $(KDIR) M=$(PWD) modules + +clean: + $(MAKE) -C $(KDIR) M=$(PWD) clean + +install: + $(MAKE) -C $(KDIR) M=$(PWD) modules_install diff --git a/drivers/staging/apple-bce/apple_bce.c b/drivers/staging/apple-bce/apple_bce.c new file mode 100644 index 000000000000..4fd2415d7028 --- /dev/null +++ b/drivers/staging/apple-bce/apple_bce.c @@ -0,0 +1,445 @@ +#include "apple_bce.h" +#include +#include +#include "audio/audio.h" +#include + +static dev_t bce_chrdev; +static struct class *bce_class; + +struct apple_bce_device *global_bce; + +static int bce_create_command_queues(struct apple_bce_device *bce); +static void bce_free_command_queues(struct apple_bce_device *bce); +static irqreturn_t bce_handle_mb_irq(int irq, void *dev); +static irqreturn_t bce_handle_dma_irq(int irq, void *dev); +static int bce_fw_version_handshake(struct apple_bce_device *bce); +static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq); + +static int apple_bce_probe(struct pci_dev *dev, const struct pci_device_id *id) +{ + struct apple_bce_device *bce = NULL; + int status = 0; + int nvec; + + pr_info("apple-bce: capturing our device\n"); + + if (pci_enable_device(dev)) + return -ENODEV; + if (pci_request_regions(dev, "apple-bce")) { + status = -ENODEV; + goto fail; + } + pci_set_master(dev); + nvec = pci_alloc_irq_vectors(dev, 1, 8, PCI_IRQ_MSI); + if (nvec < 5) { + status = -EINVAL; + goto fail; + } + + bce = kzalloc(sizeof(struct apple_bce_device), GFP_KERNEL); + if (!bce) { + status = -ENOMEM; + goto fail; + } + + bce->pci = dev; + pci_set_drvdata(dev, bce); + + bce->devt = bce_chrdev; + bce->dev = device_create(bce_class, &dev->dev, bce->devt, NULL, "apple-bce"); + if (IS_ERR_OR_NULL(bce->dev)) { + status = PTR_ERR(bce_class); + goto fail; + } + + bce->reg_mem_mb = pci_iomap(dev, 4, 0); + bce->reg_mem_dma = pci_iomap(dev, 2, 0); + + if (IS_ERR_OR_NULL(bce->reg_mem_mb) || IS_ERR_OR_NULL(bce->reg_mem_dma)) { + dev_warn(&dev->dev, "apple-bce: Failed to pci_iomap required regions\n"); + goto fail; + } + + bce_mailbox_init(&bce->mbox, bce->reg_mem_mb); + bce_timestamp_init(&bce->timestamp, bce->reg_mem_mb); + + spin_lock_init(&bce->queues_lock); + ida_init(&bce->queue_ida); + + if ((status = pci_request_irq(dev, 0, bce_handle_mb_irq, NULL, dev, "bce_mbox"))) + goto fail; + if ((status = pci_request_irq(dev, 4, NULL, bce_handle_dma_irq, dev, "bce_dma"))) + goto fail_interrupt_0; + + if ((status = dma_set_mask_and_coherent(&dev->dev, DMA_BIT_MASK(37)))) { + dev_warn(&dev->dev, "dma: Setting mask failed\n"); + goto fail_interrupt; + } + + /* Gets the function 0's interface. This is needed because Apple only accepts DMA on our function if function 0 + is a bus master, so we need to work around this. */ + bce->pci0 = pci_get_slot(dev->bus, PCI_DEVFN(PCI_SLOT(dev->devfn), 0)); +#ifndef WITHOUT_NVME_PATCH + if ((status = pci_enable_device_mem(bce->pci0))) { + dev_warn(&dev->dev, "apple-bce: failed to enable function 0\n"); + goto fail_dev0; + } +#endif + pci_set_master(bce->pci0); + + bce_timestamp_start(&bce->timestamp, true); + + if ((status = bce_fw_version_handshake(bce))) + goto fail_ts; + pr_info("apple-bce: handshake done\n"); + + if ((status = bce_create_command_queues(bce))) { + pr_info("apple-bce: Creating command queues failed\n"); + goto fail_ts; + } + + global_bce = bce; + + bce_vhci_create(bce, &bce->vhci); + + return 0; + +fail_ts: + bce_timestamp_stop(&bce->timestamp); +#ifndef WITHOUT_NVME_PATCH + pci_disable_device(bce->pci0); +fail_dev0: +#endif + pci_dev_put(bce->pci0); +fail_interrupt: + pci_free_irq(dev, 4, dev); +fail_interrupt_0: + pci_free_irq(dev, 0, dev); +fail: + if (bce && bce->dev) { + device_destroy(bce_class, bce->devt); + + if (!IS_ERR_OR_NULL(bce->reg_mem_mb)) + pci_iounmap(dev, bce->reg_mem_mb); + if (!IS_ERR_OR_NULL(bce->reg_mem_dma)) + pci_iounmap(dev, bce->reg_mem_dma); + + kfree(bce); + } + + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + + if (!status) + status = -EINVAL; + return status; +} + +static int bce_create_command_queues(struct apple_bce_device *bce) +{ + int status; + struct bce_queue_memcfg *cfg; + + bce->cmd_cq = bce_alloc_cq(bce, 0, 0x20); + bce->cmd_cmdq = bce_alloc_cmdq(bce, 1, 0x20); + if (bce->cmd_cq == NULL || bce->cmd_cmdq == NULL) { + status = -ENOMEM; + goto err; + } + bce->queues[0] = (struct bce_queue *) bce->cmd_cq; + bce->queues[1] = (struct bce_queue *) bce->cmd_cmdq->sq; + + cfg = kzalloc(sizeof(struct bce_queue_memcfg), GFP_KERNEL); + if (!cfg) { + status = -ENOMEM; + goto err; + } + bce_get_cq_memcfg(bce->cmd_cq, cfg); + if ((status = bce_register_command_queue(bce, cfg, false))) + goto err; + bce_get_sq_memcfg(bce->cmd_cmdq->sq, bce->cmd_cq, cfg); + if ((status = bce_register_command_queue(bce, cfg, true))) + goto err; + kfree(cfg); + + return 0; + +err: + if (bce->cmd_cq) + bce_free_cq(bce, bce->cmd_cq); + if (bce->cmd_cmdq) + bce_free_cmdq(bce, bce->cmd_cmdq); + return status; +} + +static void bce_free_command_queues(struct apple_bce_device *bce) +{ + bce_free_cq(bce, bce->cmd_cq); + bce_free_cmdq(bce, bce->cmd_cmdq); + bce->cmd_cq = NULL; + bce->queues[0] = NULL; +} + +static irqreturn_t bce_handle_mb_irq(int irq, void *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(dev); + bce_mailbox_handle_interrupt(&bce->mbox); + return IRQ_HANDLED; +} + +static irqreturn_t bce_handle_dma_irq(int irq, void *dev) +{ + int i; + struct apple_bce_device *bce = pci_get_drvdata(dev); + spin_lock(&bce->queues_lock); + for (i = 0; i < BCE_MAX_QUEUE_COUNT; i++) + if (bce->queues[i] && bce->queues[i]->type == BCE_QUEUE_CQ) + bce_handle_cq_completions(bce, (struct bce_queue_cq *) bce->queues[i]); + spin_unlock(&bce->queues_lock); + return IRQ_HANDLED; +} + +static int bce_fw_version_handshake(struct apple_bce_device *bce) +{ + u64 result; + int status; + + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SET_FW_PROTOCOL_VERSION, BC_PROTOCOL_VERSION), + &result))) + return status; + if (BCE_MB_TYPE(result) != BCE_MB_SET_FW_PROTOCOL_VERSION || + BCE_MB_VALUE(result) != BC_PROTOCOL_VERSION) { + pr_err("apple-bce: FW version handshake failed %x:%llx\n", BCE_MB_TYPE(result), BCE_MB_VALUE(result)); + return -EINVAL; + } + return 0; +} + +static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq) +{ + int status; + int cmd_type; + u64 result; + // OS X uses an bidirectional direction, but that's not really needed + dma_addr_t a = dma_map_single(&bce->pci->dev, cfg, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); + if (dma_mapping_error(&bce->pci->dev, a)) + return -ENOMEM; + cmd_type = is_sq ? BCE_MB_REGISTER_COMMAND_SQ : BCE_MB_REGISTER_COMMAND_CQ; + status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(cmd_type, a), &result); + dma_unmap_single(&bce->pci->dev, a, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); + if (status) + return status; + if (BCE_MB_TYPE(result) != BCE_MB_REGISTER_COMMAND_QUEUE_REPLY) + return -EINVAL; + return 0; +} + +static void apple_bce_remove(struct pci_dev *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(dev); + bce->is_being_removed = true; + + bce_vhci_destroy(&bce->vhci); + + bce_timestamp_stop(&bce->timestamp); +#ifndef WITHOUT_NVME_PATCH + pci_disable_device(bce->pci0); +#endif + pci_dev_put(bce->pci0); + pci_free_irq(dev, 0, dev); + pci_free_irq(dev, 4, dev); + bce_free_command_queues(bce); + pci_iounmap(dev, bce->reg_mem_mb); + pci_iounmap(dev, bce->reg_mem_dma); + device_destroy(bce_class, bce->devt); + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + kfree(bce); +} + +static int bce_save_state_and_sleep(struct apple_bce_device *bce) +{ + int attempt, status = 0; + u64 resp; + dma_addr_t dma_addr; + void *dma_ptr = NULL; + size_t size = max(PAGE_SIZE, 4096UL); + + for (attempt = 0; attempt < 5; ++attempt) { + pr_debug("apple-bce: suspend: attempt %i, buffer size %li\n", attempt, size); + dma_ptr = dma_alloc_coherent(&bce->pci->dev, size, &dma_addr, GFP_KERNEL); + if (!dma_ptr) { + pr_err("apple-bce: suspend failed (data alloc failed)\n"); + break; + } + BUG_ON((dma_addr % 4096) != 0); + status = bce_mailbox_send(&bce->mbox, + BCE_MB_MSG(BCE_MB_SAVE_STATE_AND_SLEEP, (dma_addr & ~(4096LLU - 1)) | (size / 4096)), &resp); + if (status) { + pr_err("apple-bce: suspend failed (mailbox send)\n"); + break; + } + if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { + bce->saved_data_dma_addr = dma_addr; + bce->saved_data_dma_ptr = dma_ptr; + bce->saved_data_dma_size = size; + return 0; + } else if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE) { + dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); + /* The 0x10ff magic value was extracted from Apple's driver */ + size = (BCE_MB_VALUE(resp) + 0x10ff) & ~(4096LLU - 1); + pr_debug("apple-bce: suspend: device requested a larger buffer (%li)\n", size); + continue; + } else { + pr_err("apple-bce: suspend failed (invalid device response)\n"); + status = -EINVAL; + break; + } + } + if (dma_ptr) + dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); + if (!status) + return bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SLEEP_NO_STATE, 0), &resp); + return status; +} + +static int bce_restore_state_and_wake(struct apple_bce_device *bce) +{ + int status; + u64 resp; + if (!bce->saved_data_dma_ptr) { + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_NO_STATE, 0), &resp))) { + pr_err("apple-bce: resume with no state failed (mailbox send)\n"); + return status; + } + if (BCE_MB_TYPE(resp) != BCE_MB_RESTORE_NO_STATE) { + pr_err("apple-bce: resume with no state failed (invalid device response)\n"); + return -EINVAL; + } + return 0; + } + + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_STATE_AND_WAKE, + (bce->saved_data_dma_addr & ~(4096LLU - 1)) | (bce->saved_data_dma_size / 4096)), &resp))) { + pr_err("apple-bce: resume with state failed (mailbox send)\n"); + goto finish_with_state; + } + if (BCE_MB_TYPE(resp) != BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { + pr_err("apple-bce: resume with state failed (invalid device response)\n"); + status = -EINVAL; + goto finish_with_state; + } + +finish_with_state: + dma_free_coherent(&bce->pci->dev, bce->saved_data_dma_size, bce->saved_data_dma_ptr, bce->saved_data_dma_addr); + bce->saved_data_dma_ptr = NULL; + return status; +} + +static int apple_bce_suspend(struct device *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(to_pci_dev(dev)); + int status; + + bce_timestamp_stop(&bce->timestamp); + + if ((status = bce_save_state_and_sleep(bce))) + return status; + + return 0; +} + +static int apple_bce_resume(struct device *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(to_pci_dev(dev)); + int status; + + pci_set_master(bce->pci); + pci_set_master(bce->pci0); + + if ((status = bce_restore_state_and_wake(bce))) + return status; + + bce_timestamp_start(&bce->timestamp, false); + + return 0; +} + +static struct pci_device_id apple_bce_ids[ ] = { + { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1801) }, + { 0, }, +}; + +MODULE_DEVICE_TABLE(pci, apple_bce_ids); + +struct dev_pm_ops apple_bce_pci_driver_pm = { + .suspend = apple_bce_suspend, + .resume = apple_bce_resume +}; +struct pci_driver apple_bce_pci_driver = { + .name = "apple-bce", + .id_table = apple_bce_ids, + .probe = apple_bce_probe, + .remove = apple_bce_remove, + .driver = { + .pm = &apple_bce_pci_driver_pm + } +}; + + +static int __init apple_bce_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&bce_chrdev, 0, 1, "apple-bce"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + bce_class = class_create(THIS_MODULE, "apple-bce"); +#else + bce_class = class_create("apple-bce"); +#endif + if (IS_ERR(bce_class)) { + result = PTR_ERR(bce_class); + goto fail_class; + } + if ((result = bce_vhci_module_init())) { + pr_err("apple-bce: bce-vhci init failed"); + goto fail_class; + } + + result = pci_register_driver(&apple_bce_pci_driver); + if (result) + goto fail_drv; + + aaudio_module_init(); + + return 0; + +fail_drv: + pci_unregister_driver(&apple_bce_pci_driver); +fail_class: + class_destroy(bce_class); +fail_chrdev: + unregister_chrdev_region(bce_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} +static void __exit apple_bce_module_exit(void) +{ + pci_unregister_driver(&apple_bce_pci_driver); + + aaudio_module_exit(); + bce_vhci_module_exit(); + class_destroy(bce_class); + unregister_chrdev_region(bce_chrdev, 1); +} + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("MrARM"); +MODULE_DESCRIPTION("Apple BCE Driver"); +MODULE_VERSION("0.01"); +module_init(apple_bce_module_init); +module_exit(apple_bce_module_exit); diff --git a/drivers/staging/apple-bce/apple_bce.h b/drivers/staging/apple-bce/apple_bce.h new file mode 100644 index 000000000000..f13ab8d5742e --- /dev/null +++ b/drivers/staging/apple-bce/apple_bce.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include "mailbox.h" +#include "queue.h" +#include "vhci/vhci.h" + +#define BC_PROTOCOL_VERSION 0x20001 +#define BCE_MAX_QUEUE_COUNT 0x100 + +#define BCE_QUEUE_USER_MIN 2 +#define BCE_QUEUE_USER_MAX (BCE_MAX_QUEUE_COUNT - 1) + +struct apple_bce_device { + struct pci_dev *pci, *pci0; + dev_t devt; + struct device *dev; + void __iomem *reg_mem_mb; + void __iomem *reg_mem_dma; + struct bce_mailbox mbox; + struct bce_timestamp timestamp; + struct bce_queue *queues[BCE_MAX_QUEUE_COUNT]; + struct spinlock queues_lock; + struct ida queue_ida; + struct bce_queue_cq *cmd_cq; + struct bce_queue_cmdq *cmd_cmdq; + struct bce_queue_sq *int_sq_list[BCE_MAX_QUEUE_COUNT]; + bool is_being_removed; + + dma_addr_t saved_data_dma_addr; + void *saved_data_dma_ptr; + size_t saved_data_dma_size; + + struct bce_vhci vhci; +}; + +extern struct apple_bce_device *global_bce; \ No newline at end of file diff --git a/drivers/staging/apple-bce/audio/audio.c b/drivers/staging/apple-bce/audio/audio.c new file mode 100644 index 000000000000..bd16ddd16c1d --- /dev/null +++ b/drivers/staging/apple-bce/audio/audio.c @@ -0,0 +1,711 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "audio.h" +#include "pcm.h" +#include + +static int aaudio_alsa_index = SNDRV_DEFAULT_IDX1; +static char *aaudio_alsa_id = SNDRV_DEFAULT_STR1; + +static dev_t aaudio_chrdev; +static struct class *aaudio_class; + +static int aaudio_init_cmd(struct aaudio_device *a); +static int aaudio_init_bs(struct aaudio_device *a); +static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id); +static void aaudio_free_dev(struct aaudio_subdevice *sdev); + +static int aaudio_probe(struct pci_dev *dev, const struct pci_device_id *id) +{ + struct aaudio_device *aaudio = NULL; + struct aaudio_subdevice *sdev = NULL; + int status = 0; + u32 cfg; + + pr_info("aaudio: capturing our device\n"); + + if (pci_enable_device(dev)) + return -ENODEV; + if (pci_request_regions(dev, "aaudio")) { + status = -ENODEV; + goto fail; + } + pci_set_master(dev); + + aaudio = kzalloc(sizeof(struct aaudio_device), GFP_KERNEL); + if (!aaudio) { + status = -ENOMEM; + goto fail; + } + + aaudio->bce = global_bce; + if (!aaudio->bce) { + dev_warn(&dev->dev, "aaudio: No BCE available\n"); + status = -EINVAL; + goto fail; + } + + aaudio->pci = dev; + pci_set_drvdata(dev, aaudio); + + aaudio->devt = aaudio_chrdev; + aaudio->dev = device_create(aaudio_class, &dev->dev, aaudio->devt, NULL, "aaudio"); + if (IS_ERR_OR_NULL(aaudio->dev)) { + status = PTR_ERR(aaudio_class); + goto fail; + } + device_link_add(aaudio->dev, aaudio->bce->dev, DL_FLAG_PM_RUNTIME | DL_FLAG_AUTOREMOVE_CONSUMER); + + init_completion(&aaudio->remote_alive); + INIT_LIST_HEAD(&aaudio->subdevice_list); + + /* Init: set an unknown flag in the bitset */ + if (pci_read_config_dword(dev, 4, &cfg)) + dev_warn(&dev->dev, "aaudio: pci_read_config_dword fail\n"); + if (pci_write_config_dword(dev, 4, cfg | 6u)) + dev_warn(&dev->dev, "aaudio: pci_write_config_dword fail\n"); + + dev_info(aaudio->dev, "aaudio: bs len = %llx\n", pci_resource_len(dev, 0)); + aaudio->reg_mem_bs_dma = pci_resource_start(dev, 0); + aaudio->reg_mem_bs = pci_iomap(dev, 0, 0); + aaudio->reg_mem_cfg = pci_iomap(dev, 4, 0); + + aaudio->reg_mem_gpr = (u32 __iomem *) ((u8 __iomem *) aaudio->reg_mem_cfg + 0xC000); + + if (IS_ERR_OR_NULL(aaudio->reg_mem_bs) || IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) { + dev_warn(&dev->dev, "aaudio: Failed to pci_iomap required regions\n"); + goto fail; + } + + if (aaudio_bce_init(aaudio)) { + dev_warn(&dev->dev, "aaudio: Failed to init BCE command transport\n"); + goto fail; + } + + if (snd_card_new(aaudio->dev, aaudio_alsa_index, aaudio_alsa_id, THIS_MODULE, 0, &aaudio->card)) { + dev_err(&dev->dev, "aaudio: Failed to create ALSA card\n"); + goto fail; + } + + strcpy(aaudio->card->shortname, "Apple T2 Audio"); + strcpy(aaudio->card->longname, "Apple T2 Audio"); + strcpy(aaudio->card->mixername, "Apple T2 Audio"); + /* Dynamic alsa ids start at 100 */ + aaudio->next_alsa_id = 100; + + if (aaudio_init_cmd(aaudio)) { + dev_err(&dev->dev, "aaudio: Failed to initialize over BCE\n"); + goto fail_snd; + } + + if (aaudio_init_bs(aaudio)) { + dev_err(&dev->dev, "aaudio: Failed to initialize BufferStruct\n"); + goto fail_snd; + } + + if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { + dev_err(&dev->dev, "Failed to set remote access\n"); + return status; + } + + if (snd_card_register(aaudio->card)) { + dev_err(&dev->dev, "aaudio: Failed to register ALSA sound device\n"); + goto fail_snd; + } + + list_for_each_entry(sdev, &aaudio->subdevice_list, list) { + struct aaudio_buffer_struct_device *dev = &aaudio->bs->devices[sdev->buf_id]; + + if (sdev->out_stream_cnt == 1 && !strcmp(dev->name, "Speaker")) { + struct snd_pcm_hardware *hw = sdev->out_streams[0].alsa_hw_desc; + + snprintf(aaudio->card->driver, sizeof(aaudio->card->driver) / sizeof(char), "AppleT2x%d", hw->channels_min); + } + } + + return 0; + +fail_snd: + snd_card_free(aaudio->card); +fail: + if (aaudio && aaudio->dev) + device_destroy(aaudio_class, aaudio->devt); + kfree(aaudio); + + if (!IS_ERR_OR_NULL(aaudio->reg_mem_bs)) + pci_iounmap(dev, aaudio->reg_mem_bs); + if (!IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) + pci_iounmap(dev, aaudio->reg_mem_cfg); + + pci_release_regions(dev); + pci_disable_device(dev); + + if (!status) + status = -EINVAL; + return status; +} + + + +static void aaudio_remove(struct pci_dev *dev) +{ + struct aaudio_subdevice *sdev; + struct aaudio_device *aaudio = pci_get_drvdata(dev); + + snd_card_free(aaudio->card); + while (!list_empty(&aaudio->subdevice_list)) { + sdev = list_first_entry(&aaudio->subdevice_list, struct aaudio_subdevice, list); + list_del(&sdev->list); + aaudio_free_dev(sdev); + } + pci_iounmap(dev, aaudio->reg_mem_bs); + pci_iounmap(dev, aaudio->reg_mem_cfg); + device_destroy(aaudio_class, aaudio->devt); + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + kfree(aaudio); +} + +static int aaudio_suspend(struct device *dev) +{ + struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); + + if (aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_OFF)) + dev_warn(aaudio->dev, "Failed to reset remote access\n"); + + pci_disable_device(aaudio->pci); + return 0; +} + +static int aaudio_resume(struct device *dev) +{ + int status; + struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); + + if ((status = pci_enable_device(aaudio->pci))) + return status; + pci_set_master(aaudio->pci); + + if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { + dev_err(aaudio->dev, "Failed to set remote access\n"); + return status; + } + + return 0; +} + +static int aaudio_init_cmd(struct aaudio_device *a) +{ + int status; + struct aaudio_send_ctx sctx; + struct aaudio_msg buf; + u64 dev_cnt, dev_i; + aaudio_device_id_t *dev_l; + + if ((status = aaudio_send(a, &sctx, 500, + aaudio_msg_write_alive_notification, 1, 3))) { + dev_err(a->dev, "Sending alive notification failed\n"); + return status; + } + + if (wait_for_completion_timeout(&a->remote_alive, msecs_to_jiffies(500)) == 0) { + dev_err(a->dev, "Timed out waiting for remote\n"); + return -ETIMEDOUT; + } + dev_info(a->dev, "Continuing init\n"); + + buf = aaudio_reply_alloc(); + if ((status = aaudio_cmd_get_device_list(a, &buf, &dev_l, &dev_cnt))) { + dev_err(a->dev, "Failed to get device list\n"); + aaudio_reply_free(&buf); + return status; + } + for (dev_i = 0; dev_i < dev_cnt; ++dev_i) + aaudio_init_dev(a, dev_l[dev_i]); + aaudio_reply_free(&buf); + + return 0; +} + +static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm); +static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev); + +static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id) +{ + struct aaudio_subdevice *sdev; + struct aaudio_msg buf = aaudio_reply_alloc(); + u64 uid_len, stream_cnt, i; + aaudio_object_id_t *stream_list; + char *uid; + + sdev = kzalloc(sizeof(struct aaudio_subdevice), GFP_KERNEL); + + if (aaudio_cmd_get_property(a, &buf, dev_id, dev_id, AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_UID, 0), + NULL, 0, (void **) &uid, &uid_len) || uid_len > AAUDIO_DEVICE_MAX_UID_LEN) { + dev_err(a->dev, "Failed to get device uid for device %llx\n", dev_id); + goto fail; + } + dev_info(a->dev, "Remote device %llx %.*s\n", dev_id, (int) uid_len, uid); + + sdev->a = a; + INIT_LIST_HEAD(&sdev->list); + sdev->dev_id = dev_id; + sdev->buf_id = AAUDIO_BUFFER_ID_NONE; + strncpy(sdev->uid, uid, uid_len); + sdev->uid[uid_len + 1] = '\0'; + + if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_INPUT, AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->in_latency, sizeof(u32))) + dev_warn(a->dev, "Failed to query device input latency\n"); + if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->out_latency, sizeof(u32))) + dev_warn(a->dev, "Failed to query device output latency\n"); + + if (aaudio_cmd_get_input_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { + dev_err(a->dev, "Failed to get input stream list for device %llx\n", dev_id); + goto fail; + } + if (stream_cnt > AAUDIO_DEIVCE_MAX_INPUT_STREAMS) { + dev_warn(a->dev, "Device %s input stream count %llu is larger than the supported count of %u\n", + sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_INPUT_STREAMS); + stream_cnt = AAUDIO_DEIVCE_MAX_INPUT_STREAMS; + } + sdev->in_stream_cnt = stream_cnt; + for (i = 0; i < stream_cnt; i++) { + sdev->in_streams[i].id = stream_list[i]; + sdev->in_streams[i].buffer_cnt = 0; + aaudio_init_stream_info(sdev, &sdev->in_streams[i]); + sdev->in_streams[i].latency += sdev->in_latency; + } + + if (aaudio_cmd_get_output_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { + dev_err(a->dev, "Failed to get output stream list for device %llx\n", dev_id); + goto fail; + } + if (stream_cnt > AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS) { + dev_warn(a->dev, "Device %s input stream count %llu is larger than the supported count of %u\n", + sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS); + stream_cnt = AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS; + } + sdev->out_stream_cnt = stream_cnt; + for (i = 0; i < stream_cnt; i++) { + sdev->out_streams[i].id = stream_list[i]; + sdev->out_streams[i].buffer_cnt = 0; + aaudio_init_stream_info(sdev, &sdev->out_streams[i]); + sdev->out_streams[i].latency += sdev->in_latency; + } + + if (sdev->is_pcm) + aaudio_create_pcm(sdev); + /* Headphone Jack status */ + if (!strcmp(sdev->uid, "Codec Output")) { + if (snd_jack_new(a->card, sdev->uid, SND_JACK_HEADPHONE, &sdev->jack, true, false)) + dev_warn(a->dev, "Failed to create an attached jack for %s\n", sdev->uid); + aaudio_cmd_property_listener(a, sdev->dev_id, sdev->dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_JACK_PLUGGED, 0)); + aaudio_handle_jack_connection_change(sdev); + } + + aaudio_reply_free(&buf); + + list_add_tail(&sdev->list, &a->subdevice_list); + return; + +fail: + aaudio_reply_free(&buf); + kfree(sdev); +} + +static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm) +{ + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_PHYS_FORMAT, 0), NULL, 0, + &strm->desc, sizeof(strm->desc))) + dev_warn(sdev->a->dev, "Failed to query stream descriptor\n"); + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_LATENCY, 0), NULL, 0, &strm->latency, sizeof(u32))) + dev_warn(sdev->a->dev, "Failed to query stream latency\n"); + if (strm->desc.format_id == AAUDIO_FORMAT_LPCM) + sdev->is_pcm = true; +} + +static void aaudio_free_dev(struct aaudio_subdevice *sdev) +{ + size_t i; + for (i = 0; i < sdev->in_stream_cnt; i++) { + if (sdev->in_streams[i].alsa_hw_desc) + kfree(sdev->in_streams[i].alsa_hw_desc); + if (sdev->in_streams[i].buffers) + kfree(sdev->in_streams[i].buffers); + } + for (i = 0; i < sdev->out_stream_cnt; i++) { + if (sdev->out_streams[i].alsa_hw_desc) + kfree(sdev->out_streams[i].alsa_hw_desc); + if (sdev->out_streams[i].buffers) + kfree(sdev->out_streams[i].buffers); + } + kfree(sdev); +} + +static struct aaudio_subdevice *aaudio_find_dev_by_dev_id(struct aaudio_device *a, aaudio_device_id_t dev_id) +{ + struct aaudio_subdevice *sdev; + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (dev_id == sdev->dev_id) + return sdev; + } + return NULL; +} + +static struct aaudio_subdevice *aaudio_find_dev_by_uid(struct aaudio_device *a, const char *uid) +{ + struct aaudio_subdevice *sdev; + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (!strcmp(uid, sdev->uid)) + return sdev; + } + return NULL; +} + +static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm); +static void aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm); + +static int aaudio_init_bs(struct aaudio_device *a) +{ + int i, j; + struct aaudio_buffer_struct_device *dev; + struct aaudio_subdevice *sdev; + u32 ver, sig, bs_base; + + ver = ioread32(&a->reg_mem_gpr[0]); + if (ver < 3) { + dev_err(a->dev, "aaudio: Bad GPR version (%u)", ver); + return -EINVAL; + } + sig = ioread32(&a->reg_mem_gpr[1]); + if (sig != AAUDIO_SIG) { + dev_err(a->dev, "aaudio: Bad GPR sig (%x)", sig); + return -EINVAL; + } + bs_base = ioread32(&a->reg_mem_gpr[2]); + a->bs = (struct aaudio_buffer_struct *) ((u8 *) a->reg_mem_bs + bs_base); + if (a->bs->signature != AAUDIO_SIG) { + dev_err(a->dev, "aaudio: Bad BufferStruct sig (%x)", a->bs->signature); + return -EINVAL; + } + dev_info(a->dev, "aaudio: BufferStruct ver = %i\n", a->bs->version); + dev_info(a->dev, "aaudio: Num devices = %i\n", a->bs->num_devices); + for (i = 0; i < a->bs->num_devices; i++) { + dev = &a->bs->devices[i]; + dev_info(a->dev, "aaudio: Device %i %s\n", i, dev->name); + + sdev = aaudio_find_dev_by_uid(a, dev->name); + if (!sdev) { + dev_err(a->dev, "aaudio: Subdevice not found for BufferStruct device %s\n", dev->name); + continue; + } + sdev->buf_id = (u8) i; + dev->num_input_streams = 0; + for (j = 0; j < dev->num_output_streams; j++) { + dev_info(a->dev, "aaudio: Device %i Stream %i: Output; Buffer Count = %i\n", i, j, + dev->output_streams[j].num_buffers); + if (j < sdev->out_stream_cnt) + aaudio_init_bs_stream(a, &sdev->out_streams[j], &dev->output_streams[j]); + } + } + + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (sdev->buf_id != AAUDIO_BUFFER_ID_NONE) + continue; + sdev->buf_id = i; + dev_info(a->dev, "aaudio: Created device %i %s\n", i, sdev->uid); + strcpy(a->bs->devices[i].name, sdev->uid); + a->bs->devices[i].num_input_streams = 0; + a->bs->devices[i].num_output_streams = 0; + a->bs->num_devices = ++i; + } + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (sdev->in_stream_cnt == 1) { + dev_info(a->dev, "aaudio: Device %i Host Stream; Input\n", sdev->buf_id); + aaudio_init_bs_stream_host(a, &sdev->in_streams[0], &a->bs->devices[sdev->buf_id].input_streams[0]); + a->bs->devices[sdev->buf_id].num_input_streams = 1; + wmb(); + + if (aaudio_cmd_set_input_stream_address_ranges(a, sdev->dev_id)) { + dev_err(a->dev, "aaudio: Failed to set input stream address ranges\n"); + } + } + } + + return 0; +} + +static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm) +{ + size_t i; + strm->buffer_cnt = bs_strm->num_buffers; + if (bs_strm->num_buffers > AAUDIO_DEIVCE_MAX_BUFFER_COUNT) { + dev_warn(a->dev, "BufferStruct buffer count %u exceeds driver limit of %u\n", bs_strm->num_buffers, + AAUDIO_DEIVCE_MAX_BUFFER_COUNT); + strm->buffer_cnt = AAUDIO_DEIVCE_MAX_BUFFER_COUNT; + } + if (!strm->buffer_cnt) + return; + strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); + if (!strm->buffers) { + dev_err(a->dev, "Buffer list allocation failed\n"); + return; + } + for (i = 0; i < strm->buffer_cnt; i++) { + strm->buffers[i].dma_addr = a->reg_mem_bs_dma + (dma_addr_t) bs_strm->buffers[i].address; + strm->buffers[i].ptr = a->reg_mem_bs + bs_strm->buffers[i].address; + strm->buffers[i].size = bs_strm->buffers[i].size; + } + + if (strm->buffer_cnt == 1) { + strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); + if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { + kfree(strm->alsa_hw_desc); + strm->alsa_hw_desc = NULL; + } + } +} + +static void aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm) +{ + size_t size; + dma_addr_t dma_addr; + void *dma_ptr; + size = strm->desc.bytes_per_packet * 16640; + dma_ptr = dma_alloc_coherent(&a->pci->dev, size, &dma_addr, GFP_KERNEL); + if (!dma_ptr) { + dev_err(a->dev, "dma_alloc_coherent failed\n"); + return; + } + bs_strm->buffers[0].address = dma_addr; + bs_strm->buffers[0].size = size; + bs_strm->num_buffers = 1; + + memset(dma_ptr, 0, size); + + strm->buffer_cnt = 1; + strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); + if (!strm->buffers) { + dev_err(a->dev, "Buffer list allocation failed\n"); + return; + } + strm->buffers[0].dma_addr = dma_addr; + strm->buffers[0].ptr = dma_ptr; + strm->buffers[0].size = size; + + strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); + if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { + kfree(strm->alsa_hw_desc); + strm->alsa_hw_desc = NULL; + } +} + +static void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg); + +void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_send_ctx sctx; + struct aaudio_msg_base base; + if (aaudio_msg_read_base(msg, &base)) + return; + switch (base.msg) { + case AAUDIO_MSG_NOTIFICATION_BOOT: + dev_info(a->dev, "Received boot notification from remote\n"); + + /* Resend the alive notify */ + if (aaudio_send(a, &sctx, 500, + aaudio_msg_write_alive_notification, 1, 3)) { + pr_err("Sending alive notification failed\n"); + } + break; + case AAUDIO_MSG_NOTIFICATION_ALIVE: + dev_info(a->dev, "Received alive notification from remote\n"); + complete_all(&a->remote_alive); + break; + case AAUDIO_MSG_PROPERTY_CHANGED: + aaudio_handle_prop_change(a, msg); + break; + default: + dev_info(a->dev, "Unhandled notification %i", base.msg); + break; + } +} + +struct aaudio_prop_change_work_struct { + struct work_struct ws; + struct aaudio_device *a; + aaudio_device_id_t dev; + aaudio_object_id_t obj; + struct aaudio_prop_addr prop; +}; + +static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev) +{ + u32 plugged; + if (!sdev->jack) + return; + /* NOTE: Apple made the plug status scoped to the input and output streams. This makes no sense for us, so I just + * always pick the OUTPUT status. */ + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, sdev->dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_JACK_PLUGGED, 0), NULL, 0, &plugged, sizeof(plugged))) { + dev_err(sdev->a->dev, "Failed to get jack enable status\n"); + return; + } + dev_dbg(sdev->a->dev, "Jack is now %s\n", plugged ? "plugged" : "unplugged"); + snd_jack_report(sdev->jack, plugged ? sdev->jack->type : 0); +} + +void aaudio_handle_prop_change_work(struct work_struct *ws) +{ + struct aaudio_prop_change_work_struct *work = container_of(ws, struct aaudio_prop_change_work_struct, ws); + struct aaudio_subdevice *sdev; + + sdev = aaudio_find_dev_by_dev_id(work->a, work->dev); + if (!sdev) { + dev_err(work->a->dev, "Property notification change: device not found\n"); + goto done; + } + dev_dbg(work->a->dev, "Property changed for device: %s\n", sdev->uid); + + if (work->prop.scope == AAUDIO_PROP_SCOPE_OUTPUT && work->prop.selector == AAUDIO_PROP_JACK_PLUGGED) { + aaudio_handle_jack_connection_change(sdev); + } + +done: + kfree(work); +} + +void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg) +{ + /* NOTE: This is a scheduled work because this callback will generally need to query device information and this + * is not possible when we are in the reply parsing code's context. */ + struct aaudio_prop_change_work_struct *work; + work = kmalloc(sizeof(struct aaudio_prop_change_work_struct), GFP_KERNEL); + work->a = a; + INIT_WORK(&work->ws, aaudio_handle_prop_change_work); + aaudio_msg_read_property_changed(msg, &work->dev, &work->obj, &work->prop); + schedule_work(&work->ws); +} + +#define aaudio_send_cmd_response(a, sctx, msg, fn, ...) \ + if (aaudio_send_with_tag(a, sctx, ((struct aaudio_msg_header *) msg->data)->tag, 500, fn, ##__VA_ARGS__)) \ + pr_err("aaudio: Failed to reply to a command\n"); + +void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg) +{ + ktime_t time_os = ktime_get_boottime(); + struct aaudio_send_ctx sctx; + struct aaudio_subdevice *sdev; + u64 devid, timestamp, update_seed; + aaudio_msg_read_update_timestamp(msg, &devid, ×tamp, &update_seed); + dev_dbg(a->dev, "Received timestamp update for dev=%llx ts=%llx seed=%llx\n", devid, timestamp, update_seed); + + sdev = aaudio_find_dev_by_dev_id(a, devid); + aaudio_handle_timestamp(sdev, time_os, timestamp); + + aaudio_send_cmd_response(a, &sctx, msg, + aaudio_msg_write_update_timestamp_response); +} + +void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_msg_base base; + if (aaudio_msg_read_base(msg, &base)) + return; + switch (base.msg) { + case AAUDIO_MSG_UPDATE_TIMESTAMP: + aaudio_handle_cmd_timestamp(a, msg); + break; + default: + dev_info(a->dev, "Unhandled device command %i", base.msg); + break; + } +} + +static struct pci_device_id aaudio_ids[ ] = { + { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1803) }, + { 0, }, +}; + +struct dev_pm_ops aaudio_pci_driver_pm = { + .suspend = aaudio_suspend, + .resume = aaudio_resume +}; +struct pci_driver aaudio_pci_driver = { + .name = "aaudio", + .id_table = aaudio_ids, + .probe = aaudio_probe, + .remove = aaudio_remove, + .driver = { + .pm = &aaudio_pci_driver_pm + } +}; + + +int aaudio_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&aaudio_chrdev, 0, 1, "aaudio"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + aaudio_class = class_create(THIS_MODULE, "aaudio"); +#else + aaudio_class = class_create("aaudio"); +#endif + if (IS_ERR(aaudio_class)) { + result = PTR_ERR(aaudio_class); + goto fail_class; + } + + result = pci_register_driver(&aaudio_pci_driver); + if (result) + goto fail_drv; + return 0; + +fail_drv: + pci_unregister_driver(&aaudio_pci_driver); +fail_class: + class_destroy(aaudio_class); +fail_chrdev: + unregister_chrdev_region(aaudio_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} + +void aaudio_module_exit(void) +{ + pci_unregister_driver(&aaudio_pci_driver); + class_destroy(aaudio_class); + unregister_chrdev_region(aaudio_chrdev, 1); +} + +struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[] = { + {"Speaker", 0}, + {"Digital Mic", 1}, + {"Codec Output", 2}, + {"Codec Input", 3}, + {"Bridge Loopback", 4}, + {} +}; + +module_param_named(index, aaudio_alsa_index, int, 0444); +MODULE_PARM_DESC(index, "Index value for Apple Internal Audio soundcard."); +module_param_named(id, aaudio_alsa_id, charp, 0444); +MODULE_PARM_DESC(id, "ID string for Apple Internal Audio soundcard."); diff --git a/drivers/staging/apple-bce/audio/audio.h b/drivers/staging/apple-bce/audio/audio.h new file mode 100644 index 000000000000..004bc1e22ea4 --- /dev/null +++ b/drivers/staging/apple-bce/audio/audio.h @@ -0,0 +1,125 @@ +#ifndef AAUDIO_H +#define AAUDIO_H + +#include +#include +#include "../apple_bce.h" +#include "protocol_bce.h" +#include "description.h" + +#define AAUDIO_SIG 0x19870423 + +#define AAUDIO_DEVICE_MAX_UID_LEN 128 +#define AAUDIO_DEIVCE_MAX_INPUT_STREAMS 1 +#define AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS 1 +#define AAUDIO_DEIVCE_MAX_BUFFER_COUNT 1 + +#define AAUDIO_BUFFER_ID_NONE 0xffu + +struct snd_card; +struct snd_pcm; +struct snd_pcm_hardware; +struct snd_jack; + +struct __attribute__((packed)) __attribute__((aligned(4))) aaudio_buffer_struct_buffer { + size_t address; + size_t size; + size_t pad[4]; +}; +struct aaudio_buffer_struct_stream { + u8 num_buffers; + struct aaudio_buffer_struct_buffer buffers[100]; + char filler[32]; +}; +struct aaudio_buffer_struct_device { + char name[128]; + u8 num_input_streams; + u8 num_output_streams; + struct aaudio_buffer_struct_stream input_streams[5]; + struct aaudio_buffer_struct_stream output_streams[5]; + char filler[128]; +}; +struct aaudio_buffer_struct { + u32 version; + u32 signature; + u32 flags; + u8 num_devices; + struct aaudio_buffer_struct_device devices[20]; +}; + +struct aaudio_device; +struct aaudio_dma_buf { + dma_addr_t dma_addr; + void *ptr; + size_t size; +}; +struct aaudio_stream { + aaudio_object_id_t id; + size_t buffer_cnt; + struct aaudio_dma_buf *buffers; + + struct aaudio_apple_description desc; + struct snd_pcm_hardware *alsa_hw_desc; + u32 latency; + + bool waiting_for_first_ts; + + ktime_t remote_timestamp; + snd_pcm_sframes_t frame_min; + int started; +}; +struct aaudio_subdevice { + struct aaudio_device *a; + struct list_head list; + aaudio_device_id_t dev_id; + u32 in_latency, out_latency; + u8 buf_id; + int alsa_id; + char uid[AAUDIO_DEVICE_MAX_UID_LEN + 1]; + size_t in_stream_cnt; + struct aaudio_stream in_streams[AAUDIO_DEIVCE_MAX_INPUT_STREAMS]; + size_t out_stream_cnt; + struct aaudio_stream out_streams[AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS]; + bool is_pcm; + struct snd_pcm *pcm; + struct snd_jack *jack; +}; +struct aaudio_alsa_pcm_id_mapping { + const char *name; + int alsa_id; +}; + +struct aaudio_device { + struct pci_dev *pci; + dev_t devt; + struct device *dev; + void __iomem *reg_mem_bs; + dma_addr_t reg_mem_bs_dma; + void __iomem *reg_mem_cfg; + + u32 __iomem *reg_mem_gpr; + + struct aaudio_buffer_struct *bs; + + struct apple_bce_device *bce; + struct aaudio_bce bcem; + + struct snd_card *card; + + struct list_head subdevice_list; + int next_alsa_id; + + struct completion remote_alive; +}; + +void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg); +void aaudio_handle_prop_change_work(struct work_struct *ws); +void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg); +void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg); + +int aaudio_module_init(void); +void aaudio_module_exit(void); + +extern struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[]; + +#endif //AAUDIO_H diff --git a/drivers/staging/apple-bce/audio/description.h b/drivers/staging/apple-bce/audio/description.h new file mode 100644 index 000000000000..dfef3ab68f27 --- /dev/null +++ b/drivers/staging/apple-bce/audio/description.h @@ -0,0 +1,42 @@ +#ifndef AAUDIO_DESCRIPTION_H +#define AAUDIO_DESCRIPTION_H + +#include + +struct aaudio_apple_description { + u64 sample_rate_double; + u32 format_id; + u32 format_flags; + u32 bytes_per_packet; + u32 frames_per_packet; + u32 bytes_per_frame; + u32 channels_per_frame; + u32 bits_per_channel; + u32 reserved; +}; + +enum { + AAUDIO_FORMAT_LPCM = 0x6c70636d // 'lpcm' +}; + +enum { + AAUDIO_FORMAT_FLAG_FLOAT = 1, + AAUDIO_FORMAT_FLAG_BIG_ENDIAN = 2, + AAUDIO_FORMAT_FLAG_SIGNED = 4, + AAUDIO_FORMAT_FLAG_PACKED = 8, + AAUDIO_FORMAT_FLAG_ALIGNED_HIGH = 16, + AAUDIO_FORMAT_FLAG_NON_INTERLEAVED = 32, + AAUDIO_FORMAT_FLAG_NON_MIXABLE = 64 +}; + +static inline u64 aaudio_double_to_u64(u64 d) +{ + u8 sign = (u8) ((d >> 63) & 1); + s32 exp = (s32) ((d >> 52) & 0x7ff) - 1023; + u64 fr = d & ((1LL << 52) - 1); + if (sign || exp < 0) + return 0; + return (u64) ((1LL << exp) + (fr >> (52 - exp))); +} + +#endif //AAUDIO_DESCRIPTION_H diff --git a/drivers/staging/apple-bce/audio/pcm.c b/drivers/staging/apple-bce/audio/pcm.c new file mode 100644 index 000000000000..1026e10a9ac5 --- /dev/null +++ b/drivers/staging/apple-bce/audio/pcm.c @@ -0,0 +1,308 @@ +#include "pcm.h" +#include "audio.h" + +static u64 aaudio_get_alsa_fmtbit(struct aaudio_apple_description *desc) +{ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_FLOAT) { + if (desc->bits_per_channel == 32) { + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) + return SNDRV_PCM_FMTBIT_FLOAT_BE; + else + return SNDRV_PCM_FMTBIT_FLOAT_LE; + } else if (desc->bits_per_channel == 64) { + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) + return SNDRV_PCM_FMTBIT_FLOAT64_BE; + else + return SNDRV_PCM_FMTBIT_FLOAT64_LE; + } else { + pr_err("aaudio: unsupported bits per channel for float format: %u\n", desc->bits_per_channel); + return 0; + } + } +#define DEFINE_BPC_OPTION(val, b) \ + case val: \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) { \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ + return SNDRV_PCM_FMTBIT_S ## b ## BE; \ + else \ + return SNDRV_PCM_FMTBIT_U ## b ## BE; \ + } else { \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ + return SNDRV_PCM_FMTBIT_S ## b ## LE; \ + else \ + return SNDRV_PCM_FMTBIT_U ## b ## LE; \ + } + if (desc->format_flags & AAUDIO_FORMAT_FLAG_PACKED) { + switch (desc->bits_per_channel) { + case 8: + case 16: + case 32: + break; + DEFINE_BPC_OPTION(24, 24_3) + default: + pr_err("aaudio: unsupported bits per channel for packed format: %u\n", desc->bits_per_channel); + return 0; + } + } + if (desc->format_flags & AAUDIO_FORMAT_FLAG_ALIGNED_HIGH) { + switch (desc->bits_per_channel) { + DEFINE_BPC_OPTION(24, 32_) + default: + pr_err("aaudio: unsupported bits per channel for high-aligned format: %u\n", desc->bits_per_channel); + return 0; + } + } + switch (desc->bits_per_channel) { + case 8: + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) + return SNDRV_PCM_FMTBIT_S8; + else + return SNDRV_PCM_FMTBIT_U8; + DEFINE_BPC_OPTION(16, 16_) + DEFINE_BPC_OPTION(24, 24_) + DEFINE_BPC_OPTION(32, 32_) + default: + pr_err("aaudio: unsupported bits per channel: %u\n", desc->bits_per_channel); + return 0; + } +} +int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, + size_t buf_size) +{ + uint rate; + alsa_hw->info = (SNDRV_PCM_INFO_MMAP | + SNDRV_PCM_INFO_BLOCK_TRANSFER | + SNDRV_PCM_INFO_MMAP_VALID | + SNDRV_PCM_INFO_DOUBLE); + if (desc->format_flags & AAUDIO_FORMAT_FLAG_NON_MIXABLE) + pr_warn("aaudio: unsupported hw flag: NON_MIXABLE\n"); + if (!(desc->format_flags & AAUDIO_FORMAT_FLAG_NON_INTERLEAVED)) + alsa_hw->info |= SNDRV_PCM_INFO_INTERLEAVED; + alsa_hw->formats = aaudio_get_alsa_fmtbit(desc); + if (!alsa_hw->formats) + return -EINVAL; + rate = (uint) aaudio_double_to_u64(desc->sample_rate_double); + alsa_hw->rates = snd_pcm_rate_to_rate_bit(rate); + alsa_hw->rate_min = rate; + alsa_hw->rate_max = rate; + alsa_hw->channels_min = desc->channels_per_frame; + alsa_hw->channels_max = desc->channels_per_frame; + alsa_hw->buffer_bytes_max = buf_size; + alsa_hw->period_bytes_min = desc->bytes_per_packet; + alsa_hw->period_bytes_max = desc->bytes_per_packet; + alsa_hw->periods_min = (uint) (buf_size / desc->bytes_per_packet); + alsa_hw->periods_max = (uint) (buf_size / desc->bytes_per_packet); + pr_debug("aaudio_create_hw_info: format = %llu, rate = %u/%u. channels = %u, periods = %u, period size = %lu\n", + alsa_hw->formats, alsa_hw->rate_min, alsa_hw->rates, alsa_hw->channels_min, alsa_hw->periods_min, + alsa_hw->period_bytes_min); + return 0; +} + +static struct aaudio_stream *aaudio_pcm_stream(struct snd_pcm_substream *substream) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + if (substream->stream == SNDRV_PCM_STREAM_PLAYBACK) + return &sdev->out_streams[substream->number]; + else + return &sdev->in_streams[substream->number]; +} + +static int aaudio_pcm_open(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_open\n"); + substream->runtime->hw = *aaudio_pcm_stream(substream)->alsa_hw_desc; + + return 0; +} + +static int aaudio_pcm_close(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_close\n"); + return 0; +} + +static int aaudio_pcm_prepare(struct snd_pcm_substream *substream) +{ + return 0; +} + +static int aaudio_pcm_hw_params(struct snd_pcm_substream *substream, struct snd_pcm_hw_params *hw_params) +{ + struct aaudio_stream *astream = aaudio_pcm_stream(substream); + pr_debug("aaudio_pcm_hw_params\n"); + + if (!astream->buffer_cnt || !astream->buffers) + return -EINVAL; + + substream->runtime->dma_area = astream->buffers[0].ptr; + substream->runtime->dma_addr = astream->buffers[0].dma_addr; + substream->runtime->dma_bytes = astream->buffers[0].size; + return 0; +} + +static int aaudio_pcm_hw_free(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_hw_free\n"); + return 0; +} + +static void aaudio_pcm_start(struct snd_pcm_substream *substream) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + void *buf; + size_t s; + ktime_t time_start, time_end; + bool back_buffer; + time_start = ktime_get(); + + back_buffer = (substream->stream == SNDRV_PCM_STREAM_PLAYBACK); + + if (back_buffer) { + s = frames_to_bytes(substream->runtime, substream->runtime->control->appl_ptr); + buf = kmalloc(s, GFP_KERNEL); + memcpy_fromio(buf, substream->runtime->dma_area, s); + time_end = ktime_get(); + pr_debug("aaudio: Backed up the buffer in %lluns [%li]\n", ktime_to_ns(time_end - time_start), + substream->runtime->control->appl_ptr); + } + + stream->waiting_for_first_ts = true; + stream->frame_min = stream->latency; + + aaudio_cmd_start_io(sdev->a, sdev->dev_id); + if (back_buffer) + memcpy_toio(substream->runtime->dma_area, buf, s); + + time_end = ktime_get(); + pr_debug("aaudio: Started the audio device in %lluns\n", ktime_to_ns(time_end - time_start)); +} + +static int aaudio_pcm_trigger(struct snd_pcm_substream *substream, int cmd) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + pr_debug("aaudio_pcm_trigger %x\n", cmd); + + /* We only supports triggers on the #0 buffer */ + if (substream->number != 0) + return 0; + switch (cmd) { + case SNDRV_PCM_TRIGGER_START: + aaudio_pcm_start(substream); + stream->started = 1; + break; + case SNDRV_PCM_TRIGGER_STOP: + aaudio_cmd_stop_io(sdev->a, sdev->dev_id); + stream->started = 0; + break; + default: + return -EINVAL; + } + return 0; +} + +static snd_pcm_uframes_t aaudio_pcm_pointer(struct snd_pcm_substream *substream) +{ + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + ktime_t time_from_start; + snd_pcm_sframes_t frames; + snd_pcm_sframes_t buffer_time_length; + + if (!stream->started || stream->waiting_for_first_ts) { + pr_warn("aaudio_pcm_pointer while not started\n"); + return 0; + } + + /* Approximate the pointer based on the last received timestamp */ + time_from_start = ktime_get_boottime() - stream->remote_timestamp; + buffer_time_length = NSEC_PER_SEC * substream->runtime->buffer_size / substream->runtime->rate; + frames = (ktime_to_ns(time_from_start) % buffer_time_length) * substream->runtime->buffer_size / buffer_time_length; + if (ktime_to_ns(time_from_start) < buffer_time_length) { + if (frames < stream->frame_min) + frames = stream->frame_min; + else + stream->frame_min = 0; + } else { + if (ktime_to_ns(time_from_start) < 2 * buffer_time_length) + stream->frame_min = frames; + else + stream->frame_min = 0; /* Heavy desync */ + } + frames -= stream->latency; + if (frames < 0) + frames += ((-frames - 1) / substream->runtime->buffer_size + 1) * substream->runtime->buffer_size; + return (snd_pcm_uframes_t) frames; +} + +static struct snd_pcm_ops aaudio_pcm_ops = { + .open = aaudio_pcm_open, + .close = aaudio_pcm_close, + .ioctl = snd_pcm_lib_ioctl, + .hw_params = aaudio_pcm_hw_params, + .hw_free = aaudio_pcm_hw_free, + .prepare = aaudio_pcm_prepare, + .trigger = aaudio_pcm_trigger, + .pointer = aaudio_pcm_pointer, + .mmap = snd_pcm_lib_mmap_iomem +}; + +int aaudio_create_pcm(struct aaudio_subdevice *sdev) +{ + struct snd_pcm *pcm; + struct aaudio_alsa_pcm_id_mapping *id_mapping; + int err; + + if (!sdev->is_pcm || (sdev->in_stream_cnt == 0 && sdev->out_stream_cnt == 0)) { + return -EINVAL; + } + + for (id_mapping = aaudio_alsa_id_mappings; id_mapping->name; id_mapping++) { + if (!strcmp(sdev->uid, id_mapping->name)) { + sdev->alsa_id = id_mapping->alsa_id; + break; + } + } + if (!id_mapping->name) + sdev->alsa_id = sdev->a->next_alsa_id++; + err = snd_pcm_new(sdev->a->card, sdev->uid, sdev->alsa_id, + (int) sdev->out_stream_cnt, (int) sdev->in_stream_cnt, &pcm); + if (err < 0) + return err; + pcm->private_data = sdev; + pcm->nonatomic = 1; + sdev->pcm = pcm; + strcpy(pcm->name, sdev->uid); + snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_PLAYBACK, &aaudio_pcm_ops); + snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_CAPTURE, &aaudio_pcm_ops); + return 0; +} + +static void aaudio_handle_stream_timestamp(struct snd_pcm_substream *substream, ktime_t timestamp) +{ + unsigned long flags; + struct aaudio_stream *stream; + + stream = aaudio_pcm_stream(substream); + snd_pcm_stream_lock_irqsave(substream, flags); + stream->remote_timestamp = timestamp; + if (stream->waiting_for_first_ts) { + stream->waiting_for_first_ts = false; + snd_pcm_stream_unlock_irqrestore(substream, flags); + return; + } + snd_pcm_stream_unlock_irqrestore(substream, flags); + snd_pcm_period_elapsed(substream); +} + +void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp) +{ + struct snd_pcm_substream *substream; + + substream = sdev->pcm->streams[SNDRV_PCM_STREAM_PLAYBACK].substream; + if (substream) + aaudio_handle_stream_timestamp(substream, dev_timestamp); + substream = sdev->pcm->streams[SNDRV_PCM_STREAM_CAPTURE].substream; + if (substream) + aaudio_handle_stream_timestamp(substream, os_timestamp); +} diff --git a/drivers/staging/apple-bce/audio/pcm.h b/drivers/staging/apple-bce/audio/pcm.h new file mode 100644 index 000000000000..ea5f35fbe408 --- /dev/null +++ b/drivers/staging/apple-bce/audio/pcm.h @@ -0,0 +1,16 @@ +#ifndef AAUDIO_PCM_H +#define AAUDIO_PCM_H + +#include +#include + +struct aaudio_subdevice; +struct aaudio_apple_description; +struct snd_pcm_hardware; + +int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, size_t buf_size); +int aaudio_create_pcm(struct aaudio_subdevice *sdev); + +void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp); + +#endif //AAUDIO_PCM_H diff --git a/drivers/staging/apple-bce/audio/protocol.c b/drivers/staging/apple-bce/audio/protocol.c new file mode 100644 index 000000000000..2314813aeead --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol.c @@ -0,0 +1,347 @@ +#include "protocol.h" +#include "protocol_bce.h" +#include "audio.h" + +int aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base) +{ + if (msg->size < sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base) * 2) + return -EINVAL; + *base = *((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1)); + return 0; +} + +#define READ_START(type) \ + size_t offset = sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base); (void)offset; \ + if (((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1))->msg != type) \ + return -EINVAL; +#define READ_DEVID_VAR(devid) *devid = ((struct aaudio_msg_header *) msg->data)->device_id +#define READ_VAL(type) ({ offset += sizeof(type); *((type *) ((u8 *) msg->data + offset - sizeof(type))); }) +#define READ_VAR(type, var) *var = READ_VAL(type) + +int aaudio_msg_read_start_io_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_START_IO_RESPONSE); + return 0; +} + +int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_STOP_IO_RESPONSE); + return 0; +} + +int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, + u64 *timestamp, u64 *update_seed) +{ + READ_START(AAUDIO_MSG_UPDATE_TIMESTAMP); + READ_DEVID_VAR(devid); + READ_VAR(u64, timestamp); + READ_VAR(u64, update_seed); + return 0; +} + +int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop, void **data, u64 *data_size) +{ + READ_START(AAUDIO_MSG_GET_PROPERTY_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + READ_VAR(u64, data_size); + *data = ((u8 *) msg->data + offset); + /* offset += data_size; */ + return 0; +} + +int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj) +{ + READ_START(AAUDIO_MSG_SET_PROPERTY_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + return 0; +} + +int aaudio_msg_read_property_listener_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop) +{ + READ_START(AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + return 0; +} + +int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop) +{ + READ_START(AAUDIO_MSG_PROPERTY_CHANGED); + READ_DEVID_VAR(devid); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + return 0; +} + +int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE); + return 0; +} + +int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) +{ + READ_START(AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE); + READ_VAR(u64, str_cnt); + *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += str_cnt * sizeof(aaudio_object_id_t); */ + return 0; +} + +int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) +{ + READ_START(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE); + READ_VAR(u64, str_cnt); + *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += str_cnt * sizeof(aaudio_object_id_t); */ + return 0; +} + +int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE); + return 0; +} + +int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt) +{ + READ_START(AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE); + READ_VAR(u64, dev_cnt); + *dev_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += dev_cnt * sizeof(aaudio_device_id_t); */ + return 0; +} + +#define WRITE_START_OF_TYPE(typev, devid) \ + size_t offset = sizeof(struct aaudio_msg_header); (void) offset; \ + ((struct aaudio_msg_header *) msg->data)->type = (typev); \ + ((struct aaudio_msg_header *) msg->data)->device_id = (devid); +#define WRITE_START_COMMAND(devid) WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_COMMAND, devid) +#define WRITE_START_RESPONSE() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_RESPONSE, 0) +#define WRITE_START_NOTIFICATION() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_NOTIFICATION, 0) +#define WRITE_VAL(type, value) { *((type *) ((u8 *) msg->data + offset)) = value; offset += sizeof(value); } +#define WRITE_BIN(value, size) { memcpy((u8 *) msg->data + offset, value, size); offset += size; } +#define WRITE_BASE(type) WRITE_VAL(u32, type) WRITE_VAL(u32, 0) +#define WRITE_END() { msg->size = offset; } + +void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_START_IO); + WRITE_END(); +} + +void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_STOP_IO); + WRITE_END(); +} + +void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_GET_PROPERTY); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_VAL(u64, qualifier_size); + WRITE_BIN(qualifier, qualifier_size); + WRITE_END(); +} + +void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_SET_PROPERTY); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_VAL(u64, data_size); + WRITE_BIN(data, data_size); + WRITE_VAL(u64, qualifier_size); + WRITE_BIN(qualifier, qualifier_size); + WRITE_END(); +} + +void aaudio_msg_write_property_listener(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_PROPERTY_LISTENER); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_END(); +} + +void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES); + WRITE_END(); +} + +void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_GET_INPUT_STREAM_LIST); + WRITE_END(); +} + +void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST); + WRITE_END(); +} + +void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode) +{ + WRITE_START_COMMAND(0); + WRITE_BASE(AAUDIO_MSG_SET_REMOTE_ACCESS); + WRITE_VAL(u64, mode); + WRITE_END(); +} + +void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver) +{ + WRITE_START_NOTIFICATION(); + WRITE_BASE(AAUDIO_MSG_NOTIFICATION_ALIVE); + WRITE_VAL(u32, proto_ver); + WRITE_VAL(u32, msg_ver); + WRITE_END(); +} + +void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg) +{ + WRITE_START_RESPONSE(); + WRITE_BASE(AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE); + WRITE_END(); +} + +void aaudio_msg_write_get_device_list(struct aaudio_msg *msg) +{ + WRITE_START_COMMAND(0); + WRITE_BASE(AAUDIO_MSG_GET_DEVICE_LIST); + WRITE_END(); +} + +#define CMD_SHARED_VARS_NO_REPLY \ + int status = 0; \ + struct aaudio_send_ctx sctx; +#define CMD_SHARED_VARS \ + CMD_SHARED_VARS_NO_REPLY \ + struct aaudio_msg reply = aaudio_reply_alloc(); \ + struct aaudio_msg *buf = &reply; +#define CMD_SEND_REQUEST(fn, ...) \ + if ((status = aaudio_send_cmd_sync(a, &sctx, buf, 500, fn, ##__VA_ARGS__))) \ + return status; +#define CMD_DEF_SHARED_AND_SEND(fn, ...) \ + CMD_SHARED_VARS \ + CMD_SEND_REQUEST(fn, ##__VA_ARGS__); +#define CMD_DEF_SHARED_NO_REPLY_AND_SEND(fn, ...) \ + CMD_SHARED_VARS_NO_REPLY \ + CMD_SEND_REQUEST(fn, ##__VA_ARGS__); +#define CMD_HNDL_REPLY_NO_FREE(fn, ...) \ + status = fn(buf, ##__VA_ARGS__); \ + return status; +#define CMD_HNDL_REPLY_AND_FREE(fn, ...) \ + status = fn(buf, ##__VA_ARGS__); \ + aaudio_reply_free(&reply); \ + return status; + +int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_start_io, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_start_io_response); +} +int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_stop_io, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_stop_io_response); +} +int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_property, devid, obj, prop, qualifier, qualifier_size); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_property_response, &obj, &prop, data, data_size); +} +int aaudio_cmd_get_primitive_property(struct aaudio_device *a, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size) +{ + int status; + struct aaudio_msg reply = aaudio_reply_alloc(); + void *r_data; + u64 r_data_size; + if ((status = aaudio_cmd_get_property(a, &reply, devid, obj, prop, qualifier, qualifier_size, + &r_data, &r_data_size))) + goto finish; + if (r_data_size != data_size) { + status = -EINVAL; + goto finish; + } + memcpy(data, r_data, data_size); +finish: + aaudio_reply_free(&reply); + return status; +} +int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_property, devid, obj, prop, data, data_size, + qualifier, qualifier_size); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_property_response, &obj); +} +int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_property_listener, devid, obj, prop); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_property_listener_response, &obj, &prop); +} +int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_input_stream_address_ranges, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_input_stream_address_ranges_response); +} +int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_input_stream_list, devid); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_input_stream_list_response, str_l, str_cnt); +} +int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_output_stream_list, devid); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_output_stream_list_response, str_l, str_cnt); +} +int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_remote_access, mode); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_remote_access_response); +} +int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t **dev_l, u64 *dev_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_device_list); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_device_list_response, dev_l, dev_cnt); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/audio/protocol.h b/drivers/staging/apple-bce/audio/protocol.h new file mode 100644 index 000000000000..3427486f3f57 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol.h @@ -0,0 +1,147 @@ +#ifndef AAUDIO_PROTOCOL_H +#define AAUDIO_PROTOCOL_H + +#include + +struct aaudio_device; + +typedef u64 aaudio_device_id_t; +typedef u64 aaudio_object_id_t; + +struct aaudio_msg { + void *data; + size_t size; +}; + +struct __attribute__((packed)) aaudio_msg_header { + char tag[4]; + u8 type; + aaudio_device_id_t device_id; // Idk, use zero for commands? +}; +struct __attribute__((packed)) aaudio_msg_base { + u32 msg; + u32 status; +}; + +struct aaudio_prop_addr { + u32 scope; + u32 selector; + u32 element; +}; +#define AAUDIO_PROP(scope, sel, el) (struct aaudio_prop_addr) { scope, sel, el } + +enum { + AAUDIO_MSG_TYPE_COMMAND = 1, + AAUDIO_MSG_TYPE_RESPONSE = 2, + AAUDIO_MSG_TYPE_NOTIFICATION = 3 +}; + +enum { + AAUDIO_MSG_START_IO = 0, + AAUDIO_MSG_START_IO_RESPONSE = 1, + AAUDIO_MSG_STOP_IO = 2, + AAUDIO_MSG_STOP_IO_RESPONSE = 3, + AAUDIO_MSG_UPDATE_TIMESTAMP = 4, + AAUDIO_MSG_GET_PROPERTY = 7, + AAUDIO_MSG_GET_PROPERTY_RESPONSE = 8, + AAUDIO_MSG_SET_PROPERTY = 9, + AAUDIO_MSG_SET_PROPERTY_RESPONSE = 10, + AAUDIO_MSG_PROPERTY_LISTENER = 11, + AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE = 12, + AAUDIO_MSG_PROPERTY_CHANGED = 13, + AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES = 18, + AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE = 19, + AAUDIO_MSG_GET_INPUT_STREAM_LIST = 24, + AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE = 25, + AAUDIO_MSG_GET_OUTPUT_STREAM_LIST = 26, + AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE = 27, + AAUDIO_MSG_SET_REMOTE_ACCESS = 32, + AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE = 33, + AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE = 34, + + AAUDIO_MSG_NOTIFICATION_ALIVE = 100, + AAUDIO_MSG_GET_DEVICE_LIST = 101, + AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE = 102, + AAUDIO_MSG_NOTIFICATION_BOOT = 104 +}; + +enum { + AAUDIO_REMOTE_ACCESS_OFF = 0, + AAUDIO_REMOTE_ACCESS_ON = 2 +}; + +enum { + AAUDIO_PROP_SCOPE_GLOBAL = 0x676c6f62, // 'glob' + AAUDIO_PROP_SCOPE_INPUT = 0x696e7074, // 'inpt' + AAUDIO_PROP_SCOPE_OUTPUT = 0x6f757470 // 'outp' +}; + +enum { + AAUDIO_PROP_UID = 0x75696420, // 'uid ' + AAUDIO_PROP_BOOL_VALUE = 0x6263766c, // 'bcvl' + AAUDIO_PROP_JACK_PLUGGED = 0x6a61636b, // 'jack' + AAUDIO_PROP_SEL_VOLUME = 0x64656176, // 'deav' + AAUDIO_PROP_LATENCY = 0x6c746e63, // 'ltnc' + AAUDIO_PROP_PHYS_FORMAT = 0x70667420 // 'pft ' +}; + +int aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base); + +int aaudio_msg_read_start_io_response(struct aaudio_msg *msg); +int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg); +int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, + u64 *timestamp, u64 *update_seed); +int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop, void **data, u64 *data_size); +int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj); +int aaudio_msg_read_property_listener_response(struct aaudio_msg *msg,aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop); +int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop); +int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg); +int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg); +int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt); + +void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev); +void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev); +void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size); +void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size); +void aaudio_msg_write_property_listener(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop); +void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode); +void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver); +void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg); +void aaudio_msg_write_get_device_list(struct aaudio_msg *msg); + + +int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size); +int aaudio_cmd_get_primitive_property(struct aaudio_device *a, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); +int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); +int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop); +int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode); +int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t **dev_l, u64 *dev_cnt); + + + +#endif //AAUDIO_PROTOCOL_H diff --git a/drivers/staging/apple-bce/audio/protocol_bce.c b/drivers/staging/apple-bce/audio/protocol_bce.c new file mode 100644 index 000000000000..28f2dfd44d67 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol_bce.c @@ -0,0 +1,226 @@ +#include "protocol_bce.h" + +#include "audio.h" + +static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq); +static void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq); +static int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, + bce_sq_completion cfn); +void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count); + +int aaudio_bce_init(struct aaudio_device *dev) +{ + int status; + struct aaudio_bce *bce = &dev->bcem; + bce->cq = bce_create_cq(dev->bce, 0x80); + spin_lock_init(&bce->spinlock); + if (!bce->cq) + return -EINVAL; + if ((status = aaudio_bce_queue_init(dev, &bce->qout, "com.apple.BridgeAudio.IntelToARM", DMA_TO_DEVICE, + aaudio_bce_out_queue_completion))) { + return status; + } + if ((status = aaudio_bce_queue_init(dev, &bce->qin, "com.apple.BridgeAudio.ARMToIntel", DMA_FROM_DEVICE, + aaudio_bce_in_queue_completion))) { + return status; + } + aaudio_bce_in_queue_submit_pending(&bce->qin, bce->qin.el_count); + return 0; +} + +int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, + bce_sq_completion cfn) +{ + q->cq = dev->bcem.cq; + q->el_size = AAUDIO_BCE_QUEUE_ELEMENT_SIZE; + q->el_count = AAUDIO_BCE_QUEUE_ELEMENT_COUNT; + /* NOTE: The Apple impl uses 0x80 as the queue size, however we use 21 (in fact 20) to simplify the impl */ + q->sq = bce_create_sq(dev->bce, q->cq, name, (u32) (q->el_count + 1), direction, cfn, dev); + if (!q->sq) + return -EINVAL; + + q->data = dma_alloc_coherent(&dev->bce->pci->dev, q->el_size * q->el_count, &q->dma_addr, GFP_KERNEL); + if (!q->data) { + bce_destroy_sq(dev->bce, q->sq); + return -EINVAL; + } + return 0; +} + +static void aaudio_send_create_tag(struct aaudio_bce *b, int *tagn, char tag[4]) +{ + char tag_zero[5]; + b->tag_num = (b->tag_num + 1) % AAUDIO_BCE_QUEUE_TAG_COUNT; + *tagn = b->tag_num; + snprintf(tag_zero, 5, "S%03d", b->tag_num); + *((u32 *) tag) = *((u32 *) tag_zero); +} + +int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag) +{ + int status; + size_t index; + void *dptr; + struct aaudio_msg_header *header; + if ((status = bce_reserve_submission(b->qout.sq, &ctx->timeout))) + return status; + spin_lock_irqsave(&b->spinlock, ctx->irq_flags); + index = b->qout.data_tail; + dptr = (u8 *) b->qout.data + index * b->qout.el_size; + ctx->msg.data = dptr; + header = dptr; + if (tag) + *((u32 *) header->tag) = *((u32 *) tag); + else + aaudio_send_create_tag(b, &ctx->tag_n, header->tag); + return 0; +} + +void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx) +{ + struct bce_qe_submission *s = bce_next_submission(b->qout.sq); +#ifdef DEBUG + pr_debug("aaudio: Sending command data\n"); + print_hex_dump(KERN_DEBUG, "aaudio:OUT ", DUMP_PREFIX_NONE, 32, 1, ctx->msg.data, ctx->msg.size, true); +#endif + bce_set_submission_single(s, b->qout.dma_addr + (dma_addr_t) (ctx->msg.data - b->qout.data), ctx->msg.size); + bce_submit_to_device(b->qout.sq); + b->qout.data_tail = (b->qout.data_tail + 1) % b->qout.el_count; + spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); +} + +int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply) +{ + struct aaudio_bce_queue_entry ent; + DECLARE_COMPLETION_ONSTACK(cmpl); + ent.msg = reply; + ent.cmpl = &cmpl; + b->pending_entries[ctx->tag_n] = &ent; + __aaudio_send(b, ctx); /* unlocks the spinlock */ + ctx->timeout = wait_for_completion_timeout(&cmpl, ctx->timeout); + if (ctx->timeout == 0) { + /* Remove the pending queue entry; this will be normally handled by the completion route but + * during a timeout it won't */ + spin_lock_irqsave(&b->spinlock, ctx->irq_flags); + if (b->pending_entries[ctx->tag_n] == &ent) + b->pending_entries[ctx->tag_n] = NULL; + spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); + return -ETIMEDOUT; + } + return 0; +} + +static void aaudio_handle_reply(struct aaudio_bce *b, struct aaudio_msg *reply) +{ + const char *tag; + int tagn; + unsigned long irq_flags; + char tag_zero[5]; + struct aaudio_bce_queue_entry *entry; + + tag = ((struct aaudio_msg_header *) reply->data)->tag; + if (tag[0] != 'S') { + pr_err("aaudio_handle_reply: Unexpected tag: %.4s\n", tag); + return; + } + *((u32 *) tag_zero) = *((u32 *) tag); + tag_zero[4] = 0; + if (kstrtoint(&tag_zero[1], 10, &tagn)) { + pr_err("aaudio_handle_reply: Tag parse failed: %.4s\n", tag); + return; + } + + spin_lock_irqsave(&b->spinlock, irq_flags); + entry = b->pending_entries[tagn]; + if (entry) { + if (reply->size < entry->msg->size) + entry->msg->size = reply->size; + memcpy(entry->msg->data, reply->data, entry->msg->size); + complete(entry->cmpl); + + b->pending_entries[tagn] = NULL; + } else { + pr_err("aaudio_handle_reply: No queued item found for tag: %.4s\n", tag); + } + spin_unlock_irqrestore(&b->spinlock, irq_flags); +} + +static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq) +{ + while (bce_next_completion(sq)) { + //pr_info("aaudio: Send confirmed\n"); + bce_notify_submission_complete(sq); + } +} + +static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg); + +static void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq) +{ + struct aaudio_msg msg; + struct aaudio_device *dev = sq->userdata; + struct aaudio_bce_queue *q = &dev->bcem.qin; + struct bce_sq_completion_data *c; + size_t cnt = 0; + + mb(); + while ((c = bce_next_completion(sq))) { + msg.data = (u8 *) q->data + q->data_head * q->el_size; + msg.size = c->data_size; +#ifdef DEBUG + pr_debug("aaudio: Received command data %llx\n", c->data_size); + print_hex_dump(KERN_DEBUG, "aaudio:IN ", DUMP_PREFIX_NONE, 32, 1, msg.data, min(msg.size, 128UL), true); +#endif + aaudio_bce_in_queue_handle_msg(dev, &msg); + + q->data_head = (q->data_head + 1) % q->el_count; + + bce_notify_submission_complete(sq); + ++cnt; + } + aaudio_bce_in_queue_submit_pending(q, cnt); +} + +static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_msg_header *header = (struct aaudio_msg_header *) msg->data; + if (msg->size < sizeof(struct aaudio_msg_header)) { + pr_err("aaudio: Msg size smaller than header (%lx)", msg->size); + return; + } + if (header->type == AAUDIO_MSG_TYPE_RESPONSE) { + aaudio_handle_reply(&a->bcem, msg); + } else if (header->type == AAUDIO_MSG_TYPE_COMMAND) { + aaudio_handle_command(a, msg); + } else if (header->type == AAUDIO_MSG_TYPE_NOTIFICATION) { + aaudio_handle_notification(a, msg); + } +} + +void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count) +{ + struct bce_qe_submission *s; + while (count--) { + if (bce_reserve_submission(q->sq, NULL)) { + pr_err("aaudio: Failed to reserve an event queue submission\n"); + break; + } + s = bce_next_submission(q->sq); + bce_set_submission_single(s, q->dma_addr + (dma_addr_t) (q->data_tail * q->el_size), q->el_size); + q->data_tail = (q->data_tail + 1) % q->el_count; + } + bce_submit_to_device(q->sq); +} + +struct aaudio_msg aaudio_reply_alloc(void) +{ + struct aaudio_msg ret; + ret.size = AAUDIO_BCE_QUEUE_ELEMENT_SIZE; + ret.data = kmalloc(ret.size, GFP_KERNEL); + return ret; +} + +void aaudio_reply_free(struct aaudio_msg *reply) +{ + kfree(reply->data); +} diff --git a/drivers/staging/apple-bce/audio/protocol_bce.h b/drivers/staging/apple-bce/audio/protocol_bce.h new file mode 100644 index 000000000000..14d26c05ddf9 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol_bce.h @@ -0,0 +1,72 @@ +#ifndef AAUDIO_PROTOCOL_BCE_H +#define AAUDIO_PROTOCOL_BCE_H + +#include "protocol.h" +#include "../queue.h" + +#define AAUDIO_BCE_QUEUE_ELEMENT_SIZE 0x1000 +#define AAUDIO_BCE_QUEUE_ELEMENT_COUNT 20 + +#define AAUDIO_BCE_QUEUE_TAG_COUNT 1000 + +struct aaudio_device; + +struct aaudio_bce_queue_entry { + struct aaudio_msg *msg; + struct completion *cmpl; +}; +struct aaudio_bce_queue { + struct bce_queue_cq *cq; + struct bce_queue_sq *sq; + void *data; + dma_addr_t dma_addr; + size_t data_head, data_tail; + size_t el_size, el_count; +}; +struct aaudio_bce { + struct bce_queue_cq *cq; + struct aaudio_bce_queue qin; + struct aaudio_bce_queue qout; + int tag_num; + struct aaudio_bce_queue_entry *pending_entries[AAUDIO_BCE_QUEUE_TAG_COUNT]; + struct spinlock spinlock; +}; + +struct aaudio_send_ctx { + int status; + int tag_n; + unsigned long irq_flags; + struct aaudio_msg msg; + unsigned long timeout; +}; + +int aaudio_bce_init(struct aaudio_device *dev); +int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag); +void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx); +int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply); + +#define aaudio_send_with_tag(a, ctx, tag, tout, fn, ...) ({ \ + (ctx)->timeout = msecs_to_jiffies(tout); \ + (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), (tag)); \ + if (!(ctx)->status) { \ + fn(&(ctx)->msg, ##__VA_ARGS__); \ + __aaudio_send(&(a)->bcem, (ctx)); \ + } \ + (ctx)->status; \ +}) +#define aaudio_send(a, ctx, tout, fn, ...) aaudio_send_with_tag(a, ctx, NULL, tout, fn, ##__VA_ARGS__) + +#define aaudio_send_cmd_sync(a, ctx, reply, tout, fn, ...) ({ \ + (ctx)->timeout = msecs_to_jiffies(tout); \ + (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), NULL); \ + if (!(ctx)->status) { \ + fn(&(ctx)->msg, ##__VA_ARGS__); \ + (ctx)->status = __aaudio_send_cmd_sync(&(a)->bcem, (ctx), (reply)); \ + } \ + (ctx)->status; \ +}) + +struct aaudio_msg aaudio_reply_alloc(void); +void aaudio_reply_free(struct aaudio_msg *reply); + +#endif //AAUDIO_PROTOCOL_BCE_H diff --git a/drivers/staging/apple-bce/mailbox.c b/drivers/staging/apple-bce/mailbox.c new file mode 100644 index 000000000000..e24bd35215c0 --- /dev/null +++ b/drivers/staging/apple-bce/mailbox.c @@ -0,0 +1,151 @@ +#include "mailbox.h" +#include +#include "apple_bce.h" + +#define REG_MBOX_OUT_BASE 0x820 +#define REG_MBOX_REPLY_COUNTER 0x108 +#define REG_MBOX_REPLY_BASE 0x810 +#define REG_TIMESTAMP_BASE 0xC000 + +#define BCE_MBOX_TIMEOUT_MS 200 + +void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb) +{ + mb->reg_mb = reg_mb; + init_completion(&mb->mb_completion); +} + +int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv) +{ + u32 __iomem *regb; + + if (atomic_cmpxchg(&mb->mb_status, 0, 1) != 0) { + return -EEXIST; // We don't support two messages at once + } + reinit_completion(&mb->mb_completion); + + pr_debug("bce_mailbox_send: %llx\n", msg); + regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_OUT_BASE); + iowrite32((u32) msg, regb); + iowrite32((u32) (msg >> 32), regb + 1); + iowrite32(0, regb + 2); + iowrite32(0, regb + 3); + + wait_for_completion_timeout(&mb->mb_completion, msecs_to_jiffies(BCE_MBOX_TIMEOUT_MS)); + if (atomic_read(&mb->mb_status) != 2) { // Didn't get the reply + atomic_set(&mb->mb_status, 0); + return -ETIMEDOUT; + } + + *recv = mb->mb_result; + pr_debug("bce_mailbox_send: reply %llx\n", *recv); + + atomic_set(&mb->mb_status, 0); + return 0; +} + +static int bce_mailbox_retrive_response(struct bce_mailbox *mb) +{ + u32 __iomem *regb; + u32 lo, hi; + int count, counter; + u32 res = ioread32((u8*) mb->reg_mb + REG_MBOX_REPLY_COUNTER); + count = (res >> 20) & 0xf; + counter = count; + pr_debug("bce_mailbox_retrive_response count=%i\n", count); + while (counter--) { + regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_REPLY_BASE); + lo = ioread32(regb); + hi = ioread32(regb + 1); + ioread32(regb + 2); + ioread32(regb + 3); + pr_debug("bce_mailbox_retrive_response %llx\n", ((u64) hi << 32) | lo); + mb->mb_result = ((u64) hi << 32) | lo; + } + return count > 0 ? 0 : -ENODATA; +} + +int bce_mailbox_handle_interrupt(struct bce_mailbox *mb) +{ + int status = bce_mailbox_retrive_response(mb); + if (!status) { + atomic_set(&mb->mb_status, 2); + complete(&mb->mb_completion); + } + return status; +} + +static void bc_send_timestamp(struct timer_list *tl); + +void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg) +{ + u32 __iomem *regb; + + spin_lock_init(&ts->stop_sl); + ts->stopped = false; + + ts->reg = reg; + + regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + ioread32(regb); + mb(); + + timer_setup(&ts->timer, bc_send_timestamp, 0); +} + +void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial) +{ + unsigned long flags; + u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + if (is_initial) { + iowrite32((u32) -4, regb + 2); + iowrite32((u32) -1, regb); + } else { + iowrite32((u32) -3, regb + 2); + iowrite32((u32) -1, regb); + } + + spin_lock_irqsave(&ts->stop_sl, flags); + ts->stopped = false; + spin_unlock_irqrestore(&ts->stop_sl, flags); + mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); +} + +void bce_timestamp_stop(struct bce_timestamp *ts) +{ + unsigned long flags; + u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + spin_lock_irqsave(&ts->stop_sl, flags); + ts->stopped = true; + spin_unlock_irqrestore(&ts->stop_sl, flags); + del_timer_sync(&ts->timer); + + iowrite32((u32) -2, regb + 2); + iowrite32((u32) -1, regb); +} + +static void bc_send_timestamp(struct timer_list *tl) +{ + struct bce_timestamp *ts; + unsigned long flags; + u32 __iomem *regb; + ktime_t bt; + + ts = container_of(tl, struct bce_timestamp, timer); + regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + local_irq_save(flags); + ioread32(regb + 2); + mb(); + bt = ktime_get_boottime(); + iowrite32((u32) bt, regb + 2); + iowrite32((u32) (bt >> 32), regb); + + spin_lock(&ts->stop_sl); + if (!ts->stopped) + mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); + spin_unlock(&ts->stop_sl); + local_irq_restore(flags); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/mailbox.h b/drivers/staging/apple-bce/mailbox.h new file mode 100644 index 000000000000..f3323f95ba51 --- /dev/null +++ b/drivers/staging/apple-bce/mailbox.h @@ -0,0 +1,53 @@ +#ifndef BCE_MAILBOX_H +#define BCE_MAILBOX_H + +#include +#include +#include + +struct bce_mailbox { + void __iomem *reg_mb; + + atomic_t mb_status; // possible statuses: 0 (no msg), 1 (has active msg), 2 (got reply) + struct completion mb_completion; + uint64_t mb_result; +}; + +enum bce_message_type { + BCE_MB_REGISTER_COMMAND_SQ = 0x7, // to-device + BCE_MB_REGISTER_COMMAND_CQ = 0x8, // to-device + BCE_MB_REGISTER_COMMAND_QUEUE_REPLY = 0xB, // to-host + BCE_MB_SET_FW_PROTOCOL_VERSION = 0xC, // both + BCE_MB_SLEEP_NO_STATE = 0x14, // to-device + BCE_MB_RESTORE_NO_STATE = 0x15, // to-device + BCE_MB_SAVE_STATE_AND_SLEEP = 0x17, // to-device + BCE_MB_RESTORE_STATE_AND_WAKE = 0x18, // to-device + BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE = 0x19, // from-device + BCE_MB_SAVE_RESTORE_STATE_COMPLETE = 0x1A, // from-device +}; + +#define BCE_MB_MSG(type, value) (((u64) (type) << 58) | ((value) & 0x3FFFFFFFFFFFFFFLL)) +#define BCE_MB_TYPE(v) ((u32) (v >> 58)) +#define BCE_MB_VALUE(v) (v & 0x3FFFFFFFFFFFFFFLL) + +void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb); + +int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv); + +int bce_mailbox_handle_interrupt(struct bce_mailbox *mb); + + +struct bce_timestamp { + void __iomem *reg; + struct timer_list timer; + struct spinlock stop_sl; + bool stopped; +}; + +void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg); + +void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial); + +void bce_timestamp_stop(struct bce_timestamp *ts); + +#endif //BCEDRIVER_MAILBOX_H diff --git a/drivers/staging/apple-bce/queue.c b/drivers/staging/apple-bce/queue.c new file mode 100644 index 000000000000..bc9cd3bc6f0c --- /dev/null +++ b/drivers/staging/apple-bce/queue.c @@ -0,0 +1,390 @@ +#include "queue.h" +#include "apple_bce.h" + +#define REG_DOORBELL_BASE 0x44000 + +struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count) +{ + struct bce_queue_cq *q; + q = kzalloc(sizeof(struct bce_queue_cq), GFP_KERNEL); + q->qid = qid; + q->type = BCE_QUEUE_CQ; + q->el_count = el_count; + q->data = dma_alloc_coherent(&dev->pci->dev, el_count * sizeof(struct bce_qe_completion), + &q->dma_handle, GFP_KERNEL); + if (!q->data) { + pr_err("DMA queue memory alloc failed\n"); + kfree(q); + return NULL; + } + return q; +} + +void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) +{ + cfg->qid = (u16) cq->qid; + cfg->el_count = (u16) cq->el_count; + cfg->vector_or_cq = 0; + cfg->_pad = 0; + cfg->addr = cq->dma_handle; + cfg->length = cq->el_count * sizeof(struct bce_qe_completion); +} + +void bce_free_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + dma_free_coherent(&dev->pci->dev, cq->el_count * sizeof(struct bce_qe_completion), cq->data, cq->dma_handle); + kfree(cq); +} + +static void bce_handle_cq_completion(struct apple_bce_device *dev, struct bce_qe_completion *e, size_t *ce) +{ + struct bce_queue *target; + struct bce_queue_sq *target_sq; + struct bce_sq_completion_data *cmpl; + if (e->qid >= BCE_MAX_QUEUE_COUNT) { + pr_err("Device sent a response for qid (%u) >= BCE_MAX_QUEUE_COUNT\n", e->qid); + return; + } + target = dev->queues[e->qid]; + if (!target || target->type != BCE_QUEUE_SQ) { + pr_err("Device sent a response for qid (%u), which does not exist\n", e->qid); + return; + } + target_sq = (struct bce_queue_sq *) target; + if (target_sq->completion_tail != e->completion_index) { + pr_err("Completion index mismatch; this is likely going to make this driver unusable\n"); + return; + } + if (!target_sq->has_pending_completions) { + target_sq->has_pending_completions = true; + dev->int_sq_list[(*ce)++] = target_sq; + } + cmpl = &target_sq->completion_data[e->completion_index]; + cmpl->status = e->status; + cmpl->data_size = e->data_size; + cmpl->result = e->result; + wmb(); + target_sq->completion_tail = (target_sq->completion_tail + 1) % target_sq->el_count; +} + +void bce_handle_cq_completions(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + size_t ce = 0; + struct bce_qe_completion *e; + struct bce_queue_sq *sq; + e = bce_cq_element(cq, cq->index); + if (!(e->flags & BCE_COMPLETION_FLAG_PENDING)) + return; + mb(); + while (true) { + e = bce_cq_element(cq, cq->index); + if (!(e->flags & BCE_COMPLETION_FLAG_PENDING)) + break; + // pr_info("apple-bce: compl: %i: %i %llx %llx", e->qid, e->status, e->data_size, e->result); + bce_handle_cq_completion(dev, e, &ce); + e->flags = 0; + cq->index = (cq->index + 1) % cq->el_count; + } + mb(); + iowrite32(cq->index, (u32 *) ((u8 *) dev->reg_mem_dma + REG_DOORBELL_BASE) + cq->qid); + while (ce) { + --ce; + sq = dev->int_sq_list[ce]; + sq->completion(sq); + sq->has_pending_completions = false; + } +} + + +struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count, + bce_sq_completion compl, void *userdata) +{ + struct bce_queue_sq *q; + q = kzalloc(sizeof(struct bce_queue_sq), GFP_KERNEL); + q->qid = qid; + q->type = BCE_QUEUE_SQ; + q->el_size = el_size; + q->el_count = el_count; + q->data = dma_alloc_coherent(&dev->pci->dev, el_count * el_size, + &q->dma_handle, GFP_KERNEL); + q->completion = compl; + q->userdata = userdata; + q->completion_data = kzalloc(sizeof(struct bce_sq_completion_data) * el_count, GFP_KERNEL); + q->reg_mem_dma = dev->reg_mem_dma; + atomic_set(&q->available_commands, el_count - 1); + init_completion(&q->available_command_completion); + atomic_set(&q->available_command_completion_waiting_count, 0); + if (!q->data) { + pr_err("DMA queue memory alloc failed\n"); + kfree(q); + return NULL; + } + return q; +} + +void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) +{ + cfg->qid = (u16) sq->qid; + cfg->el_count = (u16) sq->el_count; + cfg->vector_or_cq = (u16) cq->qid; + cfg->_pad = 0; + cfg->addr = sq->dma_handle; + cfg->length = sq->el_count * sq->el_size; +} + +void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) +{ + dma_free_coherent(&dev->pci->dev, sq->el_count * sq->el_size, sq->data, sq->dma_handle); + kfree(sq); +} + +int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout) +{ + while (atomic_dec_if_positive(&sq->available_commands) < 0) { + if (!timeout || !*timeout) + return -EAGAIN; + atomic_inc(&sq->available_command_completion_waiting_count); + *timeout = wait_for_completion_timeout(&sq->available_command_completion, *timeout); + if (!*timeout) { + if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) < 0) + try_wait_for_completion(&sq->available_command_completion); /* consume the pending completion */ + } + } + return 0; +} + +void bce_cancel_submission_reservation(struct bce_queue_sq *sq) +{ + atomic_inc(&sq->available_commands); +} + +void *bce_next_submission(struct bce_queue_sq *sq) +{ + void *ret = bce_sq_element(sq, sq->tail); + sq->tail = (sq->tail + 1) % sq->el_count; + return ret; +} + +void bce_submit_to_device(struct bce_queue_sq *sq) +{ + mb(); + iowrite32(sq->tail, (u32 *) ((u8 *) sq->reg_mem_dma + REG_DOORBELL_BASE) + sq->qid); +} + +void bce_notify_submission_complete(struct bce_queue_sq *sq) +{ + sq->head = (sq->head + 1) % sq->el_count; + atomic_inc(&sq->available_commands); + if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) >= 0) { + complete(&sq->available_command_completion); + } +} + +void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size) +{ + element->addr = addr; + element->length = size; + element->segl_addr = element->segl_length = 0; +} + +static void bce_cmdq_completion(struct bce_queue_sq *q); + +struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count) +{ + struct bce_queue_cmdq *q; + q = kzalloc(sizeof(struct bce_queue_cmdq), GFP_KERNEL); + q->sq = bce_alloc_sq(dev, qid, BCE_CMD_SIZE, el_count, bce_cmdq_completion, q); + if (!q->sq) { + kfree(q); + return NULL; + } + spin_lock_init(&q->lck); + q->tres = kzalloc(sizeof(struct bce_queue_cmdq_result_el*) * el_count, GFP_KERNEL); + if (!q->tres) { + kfree(q); + return NULL; + } + return q; +} + +void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq) +{ + bce_free_sq(dev, cmdq->sq); + kfree(cmdq->tres); + kfree(cmdq); +} + +void bce_cmdq_completion(struct bce_queue_sq *q) +{ + struct bce_queue_cmdq_result_el *el; + struct bce_queue_cmdq *cmdq = q->userdata; + struct bce_sq_completion_data *result; + + spin_lock(&cmdq->lck); + while ((result = bce_next_completion(q))) { + el = cmdq->tres[cmdq->sq->head]; + if (el) { + el->result = result->result; + el->status = result->status; + mb(); + complete(&el->cmpl); + } else { + pr_err("apple-bce: Unexpected command queue completion\n"); + } + cmdq->tres[cmdq->sq->head] = NULL; + bce_notify_submission_complete(q); + } + spin_unlock(&cmdq->lck); +} + +static __always_inline void *bce_cmd_start(struct bce_queue_cmdq *cmdq, struct bce_queue_cmdq_result_el *res) +{ + void *ret; + unsigned long timeout; + init_completion(&res->cmpl); + mb(); + + timeout = msecs_to_jiffies(1000L * 60 * 5); /* wait for up to ~5 minutes */ + if (bce_reserve_submission(cmdq->sq, &timeout)) + return NULL; + + spin_lock(&cmdq->lck); + cmdq->tres[cmdq->sq->tail] = res; + ret = bce_next_submission(cmdq->sq); + return ret; +} + +static __always_inline void bce_cmd_finish(struct bce_queue_cmdq *cmdq, struct bce_queue_cmdq_result_el *res) +{ + bce_submit_to_device(cmdq->sq); + spin_unlock(&cmdq->lck); + + wait_for_completion(&res->cmpl); + mb(); +} + +u32 bce_cmd_register_queue(struct bce_queue_cmdq *cmdq, struct bce_queue_memcfg *cfg, const char *name, bool isdirout) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_register_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_REGISTER_MEMORY_QUEUE; + cmd->flags = (u16) ((name ? 2 : 0) | (isdirout ? 1 : 0)); + cmd->qid = cfg->qid; + cmd->el_count = cfg->el_count; + cmd->vector_or_cq = cfg->vector_or_cq; + memset(cmd->name, 0, sizeof(cmd->name)); + if (name) { + cmd->name_len = (u16) min(strlen(name), (size_t) sizeof(cmd->name)); + memcpy(cmd->name, name, cmd->name_len); + } else { + cmd->name_len = 0; + } + cmd->addr = cfg->addr; + cmd->length = cfg->length; + + bce_cmd_finish(cmdq, &res); + return res.status; +} + +u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_UNREGISTER_MEMORY_QUEUE; + cmd->flags = 0; + cmd->qid = qid; + bce_cmd_finish(cmdq, &res); + return res.status; +} + +u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_FLUSH_MEMORY_QUEUE; + cmd->flags = 0; + cmd->qid = qid; + bce_cmd_finish(cmdq, &res); + return res.status; +} + + +struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count) +{ + struct bce_queue_cq *cq; + struct bce_queue_memcfg cfg; + int qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); + if (qid < 0) + return NULL; + cq = bce_alloc_cq(dev, qid, el_count); + if (!cq) + return NULL; + bce_get_cq_memcfg(cq, &cfg); + if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, NULL, false) != 0) { + pr_err("apple-bce: CQ registration failed (%i)", qid); + bce_free_cq(dev, cq); + ida_simple_remove(&dev->queue_ida, (uint) qid); + return NULL; + } + dev->queues[qid] = (struct bce_queue *) cq; + return cq; +} + +struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count, + int direction, bce_sq_completion compl, void *userdata) +{ + struct bce_queue_sq *sq; + struct bce_queue_memcfg cfg; + int qid; + if (cq == NULL) + return NULL; /* cq can not be null */ + if (name == NULL) + return NULL; /* name can not be null */ + if (direction != DMA_TO_DEVICE && direction != DMA_FROM_DEVICE) + return NULL; /* unsupported direction */ + qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); + if (qid < 0) + return NULL; + sq = bce_alloc_sq(dev, qid, sizeof(struct bce_qe_submission), el_count, compl, userdata); + if (!sq) + return NULL; + bce_get_sq_memcfg(sq, cq, &cfg); + if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, name, direction != DMA_FROM_DEVICE) != 0) { + pr_err("apple-bce: SQ registration failed (%i)", qid); + bce_free_sq(dev, sq); + ida_simple_remove(&dev->queue_ida, (uint) qid); + return NULL; + } + spin_lock(&dev->queues_lock); + dev->queues[qid] = (struct bce_queue *) sq; + spin_unlock(&dev->queues_lock); + return sq; +} + +void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) cq->qid)) + pr_err("apple-bce: CQ unregister failed"); + spin_lock(&dev->queues_lock); + dev->queues[cq->qid] = NULL; + spin_unlock(&dev->queues_lock); + ida_simple_remove(&dev->queue_ida, (uint) cq->qid); + bce_free_cq(dev, cq); +} + +void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) +{ + if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) sq->qid)) + pr_err("apple-bce: CQ unregister failed"); + spin_lock(&dev->queues_lock); + dev->queues[sq->qid] = NULL; + spin_unlock(&dev->queues_lock); + ida_simple_remove(&dev->queue_ida, (uint) sq->qid); + bce_free_sq(dev, sq); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/queue.h b/drivers/staging/apple-bce/queue.h new file mode 100644 index 000000000000..8368ac5dfca8 --- /dev/null +++ b/drivers/staging/apple-bce/queue.h @@ -0,0 +1,177 @@ +#ifndef BCE_QUEUE_H +#define BCE_QUEUE_H + +#include +#include + +#define BCE_CMD_SIZE 0x40 + +struct apple_bce_device; + +enum bce_queue_type { + BCE_QUEUE_CQ, BCE_QUEUE_SQ +}; +struct bce_queue { + int qid; + int type; +}; +struct bce_queue_cq { + int qid; + int type; + u32 el_count; + dma_addr_t dma_handle; + void *data; + + u32 index; +}; +struct bce_queue_sq; +typedef void (*bce_sq_completion)(struct bce_queue_sq *q); +struct bce_sq_completion_data { + u32 status; + u64 data_size; + u64 result; +}; +struct bce_queue_sq { + int qid; + int type; + u32 el_size; + u32 el_count; + dma_addr_t dma_handle; + void *data; + void *userdata; + void __iomem *reg_mem_dma; + + atomic_t available_commands; + struct completion available_command_completion; + atomic_t available_command_completion_waiting_count; + u32 head, tail; + + u32 completion_cidx, completion_tail; + struct bce_sq_completion_data *completion_data; + bool has_pending_completions; + bce_sq_completion completion; +}; + +struct bce_queue_cmdq_result_el { + struct completion cmpl; + u32 status; + u64 result; +}; +struct bce_queue_cmdq { + struct bce_queue_sq *sq; + struct spinlock lck; + struct bce_queue_cmdq_result_el **tres; +}; + +struct bce_queue_memcfg { + u16 qid; + u16 el_count; + u16 vector_or_cq; + u16 _pad; + u64 addr; + u64 length; +}; + +enum bce_qe_completion_status { + BCE_COMPLETION_SUCCESS = 0, + BCE_COMPLETION_ERROR = 1, + BCE_COMPLETION_ABORTED = 2, + BCE_COMPLETION_NO_SPACE = 3, + BCE_COMPLETION_OVERRUN = 4 +}; +enum bce_qe_completion_flags { + BCE_COMPLETION_FLAG_PENDING = 0x8000 +}; +struct bce_qe_completion { + u64 result; + u64 data_size; + u16 qid; + u16 completion_index; + u16 status; // bce_qe_completion_status + u16 flags; // bce_qe_completion_flags +}; + +struct bce_qe_submission { + u64 length; + u64 addr; + + u64 segl_addr; + u64 segl_length; +}; + +enum bce_cmdq_command { + BCE_CMD_REGISTER_MEMORY_QUEUE = 0x20, + BCE_CMD_UNREGISTER_MEMORY_QUEUE = 0x30, + BCE_CMD_FLUSH_MEMORY_QUEUE = 0x40, + BCE_CMD_SET_MEMORY_QUEUE_PROPERTY = 0x50 +}; +struct bce_cmdq_simple_memory_queue_cmd { + u16 cmd; // bce_cmdq_command + u16 flags; + u16 qid; +}; +struct bce_cmdq_register_memory_queue_cmd { + u16 cmd; // bce_cmdq_command + u16 flags; + u16 qid; + u16 _pad; + u16 el_count; + u16 vector_or_cq; + u16 _pad2; + u16 name_len; + char name[0x20]; + u64 addr; + u64 length; +}; + +static __always_inline void *bce_sq_element(struct bce_queue_sq *q, int i) { + return (void *) ((u8 *) q->data + q->el_size * i); +} +static __always_inline void *bce_cq_element(struct bce_queue_cq *q, int i) { + return (void *) ((struct bce_qe_completion *) q->data + i); +} + +static __always_inline struct bce_sq_completion_data *bce_next_completion(struct bce_queue_sq *sq) { + struct bce_sq_completion_data *res; + rmb(); + if (sq->completion_cidx == sq->completion_tail) + return NULL; + res = &sq->completion_data[sq->completion_cidx]; + sq->completion_cidx = (sq->completion_cidx + 1) % sq->el_count; + return res; +} + +struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count); +void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg); +void bce_free_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq); +void bce_handle_cq_completions(struct apple_bce_device *dev, struct bce_queue_cq *cq); + +struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count, + bce_sq_completion compl, void *userdata); +void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg); +void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq); +int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout); +void bce_cancel_submission_reservation(struct bce_queue_sq *sq); +void *bce_next_submission(struct bce_queue_sq *sq); +void bce_submit_to_device(struct bce_queue_sq *sq); +void bce_notify_submission_complete(struct bce_queue_sq *sq); + +void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size); + +struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count); +void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq); + +u32 bce_cmd_register_queue(struct bce_queue_cmdq *cmdq, struct bce_queue_memcfg *cfg, const char *name, bool isdirout); +u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid); +u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid); + + +/* User API - Creates and registers the queue */ + +struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count); +struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count, + int direction, bce_sq_completion compl, void *userdata); +void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq); +void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq); + +#endif //BCEDRIVER_MAILBOX_H diff --git a/drivers/staging/apple-bce/queue_dma.c b/drivers/staging/apple-bce/queue_dma.c new file mode 100644 index 000000000000..b236613285c0 --- /dev/null +++ b/drivers/staging/apple-bce/queue_dma.c @@ -0,0 +1,220 @@ +#include "queue_dma.h" +#include +#include +#include "queue.h" + +static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len); +static struct bce_segment_list_element_hostinfo *bce_map_segment_list( + struct device *dev, struct scatterlist *pages, int pagen); +static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list); + +int bce_map_dma_buffer(struct device *dev, struct bce_dma_buffer *buf, struct sg_table scatterlist, + enum dma_data_direction dir) +{ + int cnt; + + buf->direction = dir; + buf->scatterlist = scatterlist; + buf->seglist_hostinfo = NULL; + + cnt = dma_map_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + if (cnt != buf->scatterlist.nents) { + pr_err("apple-bce: DMA scatter list mapping returned an unexpected count: %i\n", cnt); + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + return -EIO; + } + if (cnt == 1) + return 0; + + buf->seglist_hostinfo = bce_map_segment_list(dev, buf->scatterlist.sgl, buf->scatterlist.nents); + if (!buf->seglist_hostinfo) { + pr_err("apple-bce: Creating segment list failed\n"); + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + return -EIO; + } + return 0; +} + +int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir) +{ + int status; + struct sg_table scatterlist; + if ((status = bce_alloc_scatterlist_from_vm(&scatterlist, data, len))) + return status; + if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { + sg_free_table(&scatterlist); + return status; + } + return 0; +} + +int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir) +{ + /* Kernel memory is continuous which is great for us. */ + int status; + struct sg_table scatterlist; + if ((status = sg_alloc_table(&scatterlist, 1, GFP_KERNEL))) { + sg_free_table(&scatterlist); + return status; + } + sg_set_buf(scatterlist.sgl, data, (uint) len); + if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { + sg_free_table(&scatterlist); + return status; + } + return 0; +} + +void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf) +{ + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, buf->direction); + bce_unmap_segement_list(dev, buf->seglist_hostinfo); +} + + +static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len) +{ + int status, i; + struct page **pages; + size_t off, start_page, end_page, page_count; + off = (size_t) data % PAGE_SIZE; + start_page = (size_t) data / PAGE_SIZE; + end_page = ((size_t) data + len - 1) / PAGE_SIZE; + page_count = end_page - start_page + 1; + + if (page_count > PAGE_SIZE / sizeof(struct page *)) + pages = vmalloc(page_count * sizeof(struct page *)); + else + pages = kmalloc(page_count * sizeof(struct page *), GFP_KERNEL); + + for (i = 0; i < page_count; i++) + pages[i] = vmalloc_to_page((void *) ((start_page + i) * PAGE_SIZE)); + + if ((status = sg_alloc_table_from_pages(tbl, pages, page_count, (unsigned int) off, len, GFP_KERNEL))) { + sg_free_table(tbl); + } + + if (page_count > PAGE_SIZE / sizeof(struct page *)) + vfree(pages); + else + kfree(pages); + return status; +} + +#define BCE_ELEMENTS_PER_PAGE ((PAGE_SIZE - sizeof(struct bce_segment_list_header)) \ + / sizeof(struct bce_segment_list_element)) +#define BCE_ELEMENTS_PER_ADDITIONAL_PAGE (PAGE_SIZE / sizeof(struct bce_segment_list_element)) + +static struct bce_segment_list_element_hostinfo *bce_map_segment_list( + struct device *dev, struct scatterlist *pages, int pagen) +{ + size_t ptr, pptr = 0; + struct bce_segment_list_header theader; /* a temp header, to store the initial seg */ + struct bce_segment_list_header *header; + struct bce_segment_list_element *el, *el_end; + struct bce_segment_list_element_hostinfo *out, *pout, *out_root; + struct scatterlist *sg; + int i; + header = &theader; + out = out_root = NULL; + el = el_end = NULL; + for_each_sg(pages, sg, pagen, i) { + if (el >= el_end) { + /* allocate a new page, this will be also done for the first element */ + ptr = __get_free_page(GFP_KERNEL); + if (pptr && ptr == pptr + PAGE_SIZE) { + out->page_count++; + header->element_count += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; + el_end += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; + } else { + header = (void *) ptr; + header->element_count = BCE_ELEMENTS_PER_PAGE; + header->data_size = 0; + header->next_segl_addr = 0; + header->next_segl_length = 0; + el = (void *) (header + 1); + el_end = el + BCE_ELEMENTS_PER_PAGE; + + if (out) { + out->next = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); + out = out->next; + } else { + out_root = out = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); + } + out->page_start = (void *) ptr; + out->page_count = 1; + out->dma_start = DMA_MAPPING_ERROR; + out->next = NULL; + } + pptr = ptr; + } + el->addr = sg->dma_address; + el->length = sg->length; + header->data_size += el->length; + } + + /* DMA map */ + out = out_root; + pout = NULL; + while (out) { + out->dma_start = dma_map_single(dev, out->page_start, out->page_count * PAGE_SIZE, DMA_TO_DEVICE); + if (dma_mapping_error(dev, out->dma_start)) + goto error; + if (pout) { + header = pout->page_start; + header->next_segl_addr = out->dma_start; + header->next_segl_length = out->page_count * PAGE_SIZE; + } + pout = out; + out = out->next; + } + return out_root; + + error: + bce_unmap_segement_list(dev, out_root); + return NULL; +} + +static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list) +{ + struct bce_segment_list_element_hostinfo *next; + while (list) { + if (list->dma_start != DMA_MAPPING_ERROR) + dma_unmap_single(dev, list->dma_start, list->page_count * PAGE_SIZE, DMA_TO_DEVICE); + next = list->next; + kfree(list); + list = next; + } +} + +int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length) +{ + struct bce_segment_list_element_hostinfo *seg; + struct bce_segment_list_header *seg_header; + + seg = buf->seglist_hostinfo; + if (!seg) { + element->addr = buf->scatterlist.sgl->dma_address + offset; + element->length = length; + element->segl_addr = 0; + element->segl_length = 0; + return 0; + } + + while (seg) { + seg_header = seg->page_start; + if (offset <= seg_header->data_size) + break; + offset -= seg_header->data_size; + seg = seg->next; + } + if (!seg) + return -EINVAL; + element->addr = offset; + element->length = buf->scatterlist.sgl->dma_length; + element->segl_addr = seg->dma_start; + element->segl_length = seg->page_count * PAGE_SIZE; + return 0; +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/queue_dma.h b/drivers/staging/apple-bce/queue_dma.h new file mode 100644 index 000000000000..f8a57e50e7a3 --- /dev/null +++ b/drivers/staging/apple-bce/queue_dma.h @@ -0,0 +1,50 @@ +#ifndef BCE_QUEUE_DMA_H +#define BCE_QUEUE_DMA_H + +#include + +struct bce_qe_submission; + +struct bce_segment_list_header { + u64 element_count; + u64 data_size; + + u64 next_segl_addr; + u64 next_segl_length; +}; +struct bce_segment_list_element { + u64 addr; + u64 length; +}; + +struct bce_segment_list_element_hostinfo { + struct bce_segment_list_element_hostinfo *next; + void *page_start; + size_t page_count; + dma_addr_t dma_start; +}; + + +struct bce_dma_buffer { + enum dma_data_direction direction; + struct sg_table scatterlist; + struct bce_segment_list_element_hostinfo *seglist_hostinfo; +}; + +/* NOTE: Takes ownership of the sg_table if it succeeds. Ownership is not transferred on failure. */ +int bce_map_dma_buffer(struct device *dev, struct bce_dma_buffer *buf, struct sg_table scatterlist, + enum dma_data_direction dir); + +/* Creates a buffer from virtual memory (vmalloc) */ +int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir); + +/* Creates a buffer from kernel memory (kmalloc) */ +int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir); + +void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf); + +int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length); + +#endif //BCE_QUEUE_DMA_H diff --git a/drivers/staging/apple-bce/vhci/command.h b/drivers/staging/apple-bce/vhci/command.h new file mode 100644 index 000000000000..26619e0bccfa --- /dev/null +++ b/drivers/staging/apple-bce/vhci/command.h @@ -0,0 +1,204 @@ +#ifndef BCE_VHCI_COMMAND_H +#define BCE_VHCI_COMMAND_H + +#include "queue.h" +#include +#include + +#define BCE_VHCI_CMD_TIMEOUT_SHORT msecs_to_jiffies(2000) +#define BCE_VHCI_CMD_TIMEOUT_LONG msecs_to_jiffies(30000) + +#define BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2 2 +#define BCE_VHCI_BULK_MAX_ACTIVE_URBS (1 << BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2) + +typedef u8 bce_vhci_port_t; +typedef u8 bce_vhci_device_t; + +enum bce_vhci_command { + BCE_VHCI_CMD_CONTROLLER_ENABLE = 1, + BCE_VHCI_CMD_CONTROLLER_DISABLE = 2, + BCE_VHCI_CMD_CONTROLLER_START = 3, + BCE_VHCI_CMD_CONTROLLER_PAUSE = 4, + + BCE_VHCI_CMD_PORT_POWER_ON = 0x10, + BCE_VHCI_CMD_PORT_POWER_OFF = 0x11, + BCE_VHCI_CMD_PORT_RESUME = 0x12, + BCE_VHCI_CMD_PORT_SUSPEND = 0x13, + BCE_VHCI_CMD_PORT_RESET = 0x14, + BCE_VHCI_CMD_PORT_DISABLE = 0x15, + BCE_VHCI_CMD_PORT_STATUS = 0x16, + + BCE_VHCI_CMD_DEVICE_CREATE = 0x30, + BCE_VHCI_CMD_DEVICE_DESTROY = 0x31, + + BCE_VHCI_CMD_ENDPOINT_CREATE = 0x40, + BCE_VHCI_CMD_ENDPOINT_DESTROY = 0x41, + BCE_VHCI_CMD_ENDPOINT_SET_STATE = 0x42, + BCE_VHCI_CMD_ENDPOINT_RESET = 0x44, + + /* Device to host only */ + BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE = 0x43, + BCE_VHCI_CMD_TRANSFER_REQUEST = 0x1000, + BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS = 0x1005 +}; + +enum bce_vhci_endpoint_state { + BCE_VHCI_ENDPOINT_ACTIVE = 0, + BCE_VHCI_ENDPOINT_PAUSED = 1, + BCE_VHCI_ENDPOINT_STALLED = 2 +}; + +static inline int bce_vhci_cmd_controller_enable(struct bce_vhci_command_queue *q, u8 busNum, u16 *portMask) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_ENABLE; + cmd.param1 = 0x7100u | busNum; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); + if (!status) + *portMask = (u16) res.param2; + return status; +} +static inline int bce_vhci_cmd_controller_disable(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_DISABLE; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_controller_start(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_START; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_controller_pause(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_PAUSE; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} + +static inline int bce_vhci_cmd_port_power_on(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_POWER_ON; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_power_off(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_POWER_OFF; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_resume(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_RESUME; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_port_suspend(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_SUSPEND; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_port_reset(struct bce_vhci_command_queue *q, bce_vhci_port_t port, u32 timeout) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_RESET; + cmd.param1 = port; + cmd.param2 = timeout; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_disable(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_DISABLE; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_status(struct bce_vhci_command_queue *q, bce_vhci_port_t port, + u32 clearFlags, u32 *resStatus) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_STATUS; + cmd.param1 = port; + cmd.param2 = clearFlags & 0x560000; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (status >= 0) + *resStatus = (u32) res.param2; + return status; +} + +static inline int bce_vhci_cmd_device_create(struct bce_vhci_command_queue *q, bce_vhci_port_t port, + bce_vhci_device_t *dev) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_DEVICE_CREATE; + cmd.param1 = port; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (!status) + *dev = (bce_vhci_device_t) res.param2; + return status; +} +static inline int bce_vhci_cmd_device_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_DEVICE_DESTROY; + cmd.param1 = dev; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} + +static inline int bce_vhci_cmd_endpoint_create(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, + struct usb_endpoint_descriptor *desc) +{ + struct bce_vhci_message cmd, res; + int endpoint_type = usb_endpoint_type(desc); + int maxp = usb_endpoint_maxp(desc); + int maxp_burst = usb_endpoint_maxp_mult(desc) * maxp; + u8 max_active_requests_pow2 = 0; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_CREATE; + cmd.param1 = dev | ((desc->bEndpointAddress & 0x8Fu) << 8); + if (endpoint_type == USB_ENDPOINT_XFER_BULK) + max_active_requests_pow2 = BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2; + cmd.param2 = endpoint_type | ((max_active_requests_pow2 & 0xf) << 4) | (maxp << 16) | ((u64) maxp_burst << 32); + if (endpoint_type == USB_ENDPOINT_XFER_INT) + cmd.param2 |= (desc->bInterval - 1) << 8; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_endpoint_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_DESTROY; + cmd.param1 = dev | (endpoint << 8); + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_endpoint_set_state(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint, + enum bce_vhci_endpoint_state newState, enum bce_vhci_endpoint_state *retState) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_SET_STATE; + cmd.param1 = dev | (endpoint << 8); + cmd.param2 = (u64) newState; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (status != BCE_VHCI_INTERNAL_ERROR && status != BCE_VHCI_NO_POWER) + *retState = (enum bce_vhci_endpoint_state) res.param2; + return status; +} +static inline int bce_vhci_cmd_endpoint_reset(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_RESET; + cmd.param1 = dev | (endpoint << 8); + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} + + +#endif //BCE_VHCI_COMMAND_H diff --git a/drivers/staging/apple-bce/vhci/queue.c b/drivers/staging/apple-bce/vhci/queue.c new file mode 100644 index 000000000000..7b0b5027157b --- /dev/null +++ b/drivers/staging/apple-bce/vhci/queue.c @@ -0,0 +1,268 @@ +#include "queue.h" +#include "vhci.h" +#include "../apple_bce.h" + + +static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq); + +int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name) +{ + int status; + ret->cq = bce_create_cq(vhci->dev, VHCI_EVENT_QUEUE_EL_COUNT); + if (!ret->cq) + return -EINVAL; + ret->sq = bce_create_sq(vhci->dev, ret->cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_TO_DEVICE, + bce_vhci_message_queue_completion, ret); + if (!ret->sq) { + status = -EINVAL; + goto fail_cq; + } + ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + &ret->dma_addr, GFP_KERNEL); + if (!ret->data) { + status = -EINVAL; + goto fail_sq; + } + return 0; + +fail_sq: + bce_destroy_sq(vhci->dev, ret->sq); + ret->sq = NULL; +fail_cq: + bce_destroy_cq(vhci->dev, ret->cq); + ret->cq = NULL; + return status; +} + +void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q) +{ + if (!q->cq) + return; + dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + q->data, q->dma_addr); + bce_destroy_sq(vhci->dev, q->sq); + bce_destroy_cq(vhci->dev, q->cq); +} + +void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req) +{ + int sidx; + struct bce_qe_submission *s; + sidx = q->sq->tail; + s = bce_next_submission(q->sq); + pr_debug("bce-vhci: Send message: %x s=%x p1=%x p2=%llx\n", req->cmd, req->status, req->param1, req->param2); + q->data[sidx] = *req; + bce_set_submission_single(s, q->dma_addr + sizeof(struct bce_vhci_message) * sidx, + sizeof(struct bce_vhci_message)); + bce_submit_to_device(q->sq); +} + +static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq) +{ + while (bce_next_completion(sq)) + bce_notify_submission_complete(sq); +} + + + +static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq); + +int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_sq_completion compl) +{ + ret->vhci = vhci; + + ret->sq = bce_create_sq(vhci->dev, vhci->ev_cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_FROM_DEVICE, compl, ret); + if (!ret->sq) + return -EINVAL; + ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + &ret->dma_addr, GFP_KERNEL); + if (!ret->data) { + bce_destroy_sq(vhci->dev, ret->sq); + ret->sq = NULL; + return -EINVAL; + } + + init_completion(&ret->queue_empty_completion); + bce_vhci_event_queue_submit_pending(ret, VHCI_EVENT_PENDING_COUNT); + return 0; +} + +int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_vhci_event_queue_callback cb) +{ + ret->cb = cb; + return __bce_vhci_event_queue_create(vhci, ret, name, bce_vhci_event_queue_completion); +} + +void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_event_queue *q) +{ + if (!q->sq) + return; + dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + q->data, q->dma_addr); + bce_destroy_sq(vhci->dev, q->sq); +} + +static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq) +{ + struct bce_sq_completion_data *cd; + struct bce_vhci_event_queue *ev = sq->userdata; + struct bce_vhci_message *msg; + size_t cnt = 0; + + while ((cd = bce_next_completion(sq))) { + if (cd->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ + bce_notify_submission_complete(sq); + continue; + } + msg = &ev->data[sq->head]; + pr_debug("bce-vhci: Got event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); + ev->cb(ev, msg); + + bce_notify_submission_complete(sq); + ++cnt; + } + bce_vhci_event_queue_submit_pending(ev, cnt); + if (atomic_read(&sq->available_commands) == sq->el_count - 1) + complete(&ev->queue_empty_completion); +} + +void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count) +{ + int idx; + struct bce_qe_submission *s; + while (count--) { + if (bce_reserve_submission(q->sq, NULL)) { + pr_err("bce-vhci: Failed to reserve an event queue submission\n"); + break; + } + idx = q->sq->tail; + s = bce_next_submission(q->sq); + bce_set_submission_single(s, + q->dma_addr + idx * sizeof(struct bce_vhci_message), sizeof(struct bce_vhci_message)); + } + bce_submit_to_device(q->sq); +} + +void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q) +{ + unsigned long timeout; + reinit_completion(&q->queue_empty_completion); + if (bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, q->sq->qid)) + pr_warn("bce-vhci: failed to flush event queue\n"); + timeout = msecs_to_jiffies(5000); + while (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { + timeout = wait_for_completion_timeout(&q->queue_empty_completion, timeout); + if (timeout == 0) { + pr_err("bce-vhci: waiting for queue to be flushed timed out\n"); + break; + } + } +} + +void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q) +{ + if (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { + pr_err("bce-vhci: resume of a queue with pending submissions\n"); + return; + } + bce_vhci_event_queue_submit_pending(q, VHCI_EVENT_PENDING_COUNT); +} + +void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq) +{ + ret->mq = mq; + ret->completion.result = NULL; + init_completion(&ret->completion.completion); + spin_lock_init(&ret->completion_lock); + mutex_init(&ret->mutex); +} + +void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq) +{ + spin_lock(&cq->completion_lock); + if (cq->completion.result) { + memset(cq->completion.result, 0, sizeof(struct bce_vhci_message)); + cq->completion.result->status = BCE_VHCI_ABORT; + complete(&cq->completion.completion); + cq->completion.result = NULL; + } + spin_unlock(&cq->completion_lock); + mutex_lock(&cq->mutex); + mutex_unlock(&cq->mutex); + mutex_destroy(&cq->mutex); +} + +void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg) +{ + struct bce_vhci_command_queue_completion *c = &cq->completion; + + spin_lock(&cq->completion_lock); + if (c->result) { + *c->result = *msg; + complete(&c->completion); + c->result = NULL; + } + spin_unlock(&cq->completion_lock); +} + +static int __bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout) +{ + int status; + struct bce_vhci_command_queue_completion *c; + struct bce_vhci_message creq; + c = &cq->completion; + + if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) + return status; + + spin_lock(&cq->completion_lock); + c->result = res; + reinit_completion(&c->completion); + spin_unlock(&cq->completion_lock); + + bce_vhci_message_queue_write(cq->mq, req); + + if (!wait_for_completion_timeout(&c->completion, timeout)) { + /* we ran out of time, send cancellation */ + pr_debug("bce-vhci: command timed out req=%x\n", req->cmd); + if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) + return status; + + creq = *req; + creq.cmd |= 0x4000; + bce_vhci_message_queue_write(cq->mq, &creq); + + if (!wait_for_completion_timeout(&c->completion, 1000)) { + pr_err("bce-vhci: Possible desync, cmd cancel timed out\n"); + + spin_lock(&cq->completion_lock); + c->result = NULL; + spin_unlock(&cq->completion_lock); + return -ETIMEDOUT; + } + if ((res->cmd & ~0x8000) == creq.cmd) + return -ETIMEDOUT; + /* reply for the previous command most likely arrived */ + } + + if ((res->cmd & ~0x8000) != req->cmd) { + pr_err("bce-vhci: Possible desync, cmd reply mismatch req=%x, res=%x\n", req->cmd, res->cmd); + return -EIO; + } + if (res->status == BCE_VHCI_SUCCESS) + return 0; + return res->status; +} + +int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout) +{ + int status; + mutex_lock(&cq->mutex); + status = __bce_vhci_command_queue_execute(cq, req, res, timeout); + mutex_unlock(&cq->mutex); + return status; +} diff --git a/drivers/staging/apple-bce/vhci/queue.h b/drivers/staging/apple-bce/vhci/queue.h new file mode 100644 index 000000000000..adb705b6ba1d --- /dev/null +++ b/drivers/staging/apple-bce/vhci/queue.h @@ -0,0 +1,76 @@ +#ifndef BCE_VHCI_QUEUE_H +#define BCE_VHCI_QUEUE_H + +#include +#include "../queue.h" + +#define VHCI_EVENT_QUEUE_EL_COUNT 256 +#define VHCI_EVENT_PENDING_COUNT 32 + +struct bce_vhci; +struct bce_vhci_event_queue; + +enum bce_vhci_message_status { + BCE_VHCI_SUCCESS = 1, + BCE_VHCI_ERROR = 2, + BCE_VHCI_USB_PIPE_STALL = 3, + BCE_VHCI_ABORT = 4, + BCE_VHCI_BAD_ARGUMENT = 5, + BCE_VHCI_OVERRUN = 6, + BCE_VHCI_INTERNAL_ERROR = 7, + BCE_VHCI_NO_POWER = 8, + BCE_VHCI_UNSUPPORTED = 9 +}; +struct bce_vhci_message { + u16 cmd; + u16 status; // bce_vhci_message_status + u32 param1; + u64 param2; +}; + +struct bce_vhci_message_queue { + struct bce_queue_cq *cq; + struct bce_queue_sq *sq; + struct bce_vhci_message *data; + dma_addr_t dma_addr; +}; +typedef void (*bce_vhci_event_queue_callback)(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); +struct bce_vhci_event_queue { + struct bce_vhci *vhci; + struct bce_queue_sq *sq; + struct bce_vhci_message *data; + dma_addr_t dma_addr; + bce_vhci_event_queue_callback cb; + struct completion queue_empty_completion; +}; +struct bce_vhci_command_queue_completion { + struct bce_vhci_message *result; + struct completion completion; +}; +struct bce_vhci_command_queue { + struct bce_vhci_message_queue *mq; + struct bce_vhci_command_queue_completion completion; + struct spinlock completion_lock; + struct mutex mutex; +}; + +int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name); +void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q); +void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req); + +int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_sq_completion compl); +int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_vhci_event_queue_callback cb); +void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_event_queue *q); +void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count); +void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q); +void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q); + +void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq); +void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq); +int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout); +void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg); + +#endif //BCE_VHCI_QUEUE_H diff --git a/drivers/staging/apple-bce/vhci/transfer.c b/drivers/staging/apple-bce/vhci/transfer.c new file mode 100644 index 000000000000..8226363d69c8 --- /dev/null +++ b/drivers/staging/apple-bce/vhci/transfer.c @@ -0,0 +1,661 @@ +#include "transfer.h" +#include "../queue.h" +#include "vhci.h" +#include "../apple_bce.h" +#include + +static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq); +static void bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q); +static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q); + +static int bce_vhci_urb_init(struct bce_vhci_urb *vurb); +static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg); +static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c); + +static void bce_vhci_transfer_queue_reset_w(struct work_struct *work); + +void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, + struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir) +{ + char name[0x21]; + INIT_LIST_HEAD(&q->evq); + INIT_LIST_HEAD(&q->giveback_urb_list); + spin_lock_init(&q->urb_lock); + mutex_init(&q->pause_lock); + q->vhci = vhci; + q->endp = endp; + q->dev_addr = dev_addr; + q->endp_addr = (u8) (endp->desc.bEndpointAddress & 0x8F); + q->state = BCE_VHCI_ENDPOINT_ACTIVE; + q->active = true; + q->stalled = false; + q->max_active_requests = 1; + if (usb_endpoint_type(&endp->desc) == USB_ENDPOINT_XFER_BULK) + q->max_active_requests = BCE_VHCI_BULK_MAX_ACTIVE_URBS; + q->remaining_active_requests = q->max_active_requests; + q->cq = bce_create_cq(vhci->dev, 0x100); + INIT_WORK(&q->w_reset, bce_vhci_transfer_queue_reset_w); + q->sq_in = NULL; + if (dir == DMA_FROM_DEVICE || dir == DMA_BIDIRECTIONAL) { + snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, 0x80 | usb_endpoint_num(&endp->desc)); + q->sq_in = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_FROM_DEVICE, + bce_vhci_transfer_queue_completion, q); + } + q->sq_out = NULL; + if (dir == DMA_TO_DEVICE || dir == DMA_BIDIRECTIONAL) { + snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, usb_endpoint_num(&endp->desc)); + q->sq_out = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_TO_DEVICE, + bce_vhci_transfer_queue_completion, q); + } +} + +void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q) +{ + bce_vhci_transfer_queue_giveback(q); + bce_vhci_transfer_queue_remove_pending(q); + if (q->sq_in) + bce_destroy_sq(vhci->dev, q->sq_in); + if (q->sq_out) + bce_destroy_sq(vhci->dev, q->sq_out); + bce_destroy_cq(vhci->dev, q->cq); +} + +static inline bool bce_vhci_transfer_queue_can_init_urb(struct bce_vhci_transfer_queue *q) +{ + return q->remaining_active_requests > 0; +} + +static void bce_vhci_transfer_queue_defer_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) +{ + struct bce_vhci_list_message *lm; + lm = kmalloc(sizeof(struct bce_vhci_list_message), GFP_KERNEL); + INIT_LIST_HEAD(&lm->list); + lm->msg = *msg; + list_add_tail(&lm->list, &q->evq); +} + +static void bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + struct urb *urb; + spin_lock_irqsave(&q->urb_lock, flags); + while (!list_empty(&q->giveback_urb_list)) { + urb = list_first_entry(&q->giveback_urb_list, struct urb, urb_list); + list_del(&urb->urb_list); + + spin_unlock_irqrestore(&q->urb_lock, flags); + usb_hcd_giveback_urb(q->vhci->hcd, urb, urb->status); + spin_lock_irqsave(&q->urb_lock, flags); + } + spin_unlock_irqrestore(&q->urb_lock, flags); +} + +static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q); + +static void bce_vhci_transfer_queue_deliver_pending(struct bce_vhci_transfer_queue *q) +{ + struct urb *urb; + struct bce_vhci_list_message *lm; + + while (!list_empty(&q->endp->urb_list) && !list_empty(&q->evq)) { + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + + lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); + if (bce_vhci_urb_update(urb->hcpriv, &lm->msg) == -EAGAIN) + break; + list_del(&lm->list); + kfree(lm); + } + + /* some of the URBs could have been completed, so initialize more URBs if possible */ + bce_vhci_transfer_queue_init_pending_urbs(q); +} + +static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + struct bce_vhci_list_message *lm; + spin_lock_irqsave(&q->urb_lock, flags); + while (!list_empty(&q->evq)) { + lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); + list_del(&lm->list); + kfree(lm); + } + spin_unlock_irqrestore(&q->urb_lock, flags); +} + +void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) +{ + unsigned long flags; + struct bce_vhci_urb *turb; + struct urb *urb; + spin_lock_irqsave(&q->urb_lock, flags); + bce_vhci_transfer_queue_deliver_pending(q); + + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && + (!list_empty(&q->evq) || list_empty(&q->endp->urb_list))) { + bce_vhci_transfer_queue_defer_event(q, msg); + goto complete; + } + if (list_empty(&q->endp->urb_list)) { + pr_err("bce-vhci: [%02x] Unexpected transfer queue event\n", q->endp_addr); + goto complete; + } + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + turb = urb->hcpriv; + if (bce_vhci_urb_update(turb, msg) == -EAGAIN) { + bce_vhci_transfer_queue_defer_event(q, msg); + } else { + bce_vhci_transfer_queue_init_pending_urbs(q); + } + +complete: + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_giveback(q); +} + +static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq) +{ + unsigned long flags; + struct bce_sq_completion_data *c; + struct urb *urb; + struct bce_vhci_transfer_queue *q = sq->userdata; + spin_lock_irqsave(&q->urb_lock, flags); + while ((c = bce_next_completion(sq))) { + if (c->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ + pr_debug("bce-vhci: [%02x] Got an abort completion\n", q->endp_addr); + bce_notify_submission_complete(sq); + continue; + } + if (list_empty(&q->endp->urb_list)) { + pr_err("bce-vhci: [%02x] Got a completion while no requests are pending\n", q->endp_addr); + continue; + } + pr_debug("bce-vhci: [%02x] Got a transfer queue completion\n", q->endp_addr); + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + bce_vhci_urb_transfer_completion(urb->hcpriv, c); + bce_notify_submission_complete(sq); + } + bce_vhci_transfer_queue_deliver_pending(q); + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_giveback(q); +} + +int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + int status; + u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); + spin_lock_irqsave(&q->urb_lock, flags); + q->active = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + if (q->sq_out) { + pr_err("bce-vhci: Not implemented: wait for pending output requests\n"); + } + bce_vhci_transfer_queue_remove_pending(q); + if ((status = bce_vhci_cmd_endpoint_set_state( + &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_PAUSED, &q->state))) + return status; + if (q->state != BCE_VHCI_ENDPOINT_PAUSED) + return -EINVAL; + if (q->sq_in) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); + if (q->sq_out) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); + return 0; +} + +static void bce_vhci_urb_resume(struct bce_vhci_urb *urb); + +int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + int status; + struct urb *urb, *urbt; + struct bce_vhci_urb *vurb; + u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); + if ((status = bce_vhci_cmd_endpoint_set_state( + &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_ACTIVE, &q->state))) + return status; + if (q->state != BCE_VHCI_ENDPOINT_ACTIVE) + return -EINVAL; + spin_lock_irqsave(&q->urb_lock, flags); + q->active = true; + list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { + vurb = urb->hcpriv; + if (vurb->state == BCE_VHCI_URB_INIT_PENDING) { + if (!bce_vhci_transfer_queue_can_init_urb(q)) + break; + bce_vhci_urb_init(vurb); + } else { + bce_vhci_urb_resume(vurb); + } + } + bce_vhci_transfer_queue_deliver_pending(q); + spin_unlock_irqrestore(&q->urb_lock, flags); + return 0; +} + +int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) +{ + int ret = 0; + mutex_lock(&q->pause_lock); + if ((q->paused_by & src) != src) { + if (!q->paused_by) + ret = bce_vhci_transfer_queue_do_pause(q); + if (!ret) + q->paused_by |= src; + } + mutex_unlock(&q->pause_lock); + return ret; +} + +int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) +{ + int ret = 0; + mutex_lock(&q->pause_lock); + if (q->paused_by & src) { + if (!(q->paused_by & ~src)) + ret = bce_vhci_transfer_queue_do_resume(q); + if (!ret) + q->paused_by &= ~src; + } + mutex_unlock(&q->pause_lock); + return ret; +} + +static void bce_vhci_transfer_queue_reset_w(struct work_struct *work) +{ + unsigned long flags; + struct bce_vhci_transfer_queue *q = container_of(work, struct bce_vhci_transfer_queue, w_reset); + + mutex_lock(&q->pause_lock); + spin_lock_irqsave(&q->urb_lock, flags); + if (!q->stalled) { + spin_unlock_irqrestore(&q->urb_lock, flags); + mutex_unlock(&q->pause_lock); + return; + } + q->active = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + q->paused_by |= BCE_VHCI_PAUSE_INTERNAL_WQ; + bce_vhci_transfer_queue_remove_pending(q); + if (q->sq_in) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); + if (q->sq_out) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); + bce_vhci_cmd_endpoint_reset(&q->vhci->cq, q->dev_addr, (u8) (q->endp->desc.bEndpointAddress & 0x8F)); + spin_lock_irqsave(&q->urb_lock, flags); + q->stalled = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + mutex_unlock(&q->pause_lock); + bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); +} + +void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q) +{ + queue_work(q->vhci->tq_state_wq, &q->w_reset); +} + +static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q) +{ + struct urb *urb, *urbt; + struct bce_vhci_urb *vurb; + list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { + vurb = urb->hcpriv; + if (!bce_vhci_transfer_queue_can_init_urb(q)) + break; + if (vurb->state == BCE_VHCI_URB_INIT_PENDING) + bce_vhci_urb_init(vurb); + } +} + + + +static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout); + +int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, struct urb *urb) +{ + unsigned long flags; + int status = 0; + struct bce_vhci_urb *vurb; + vurb = kzalloc(sizeof(struct bce_vhci_urb), GFP_KERNEL); + urb->hcpriv = vurb; + + vurb->q = q; + vurb->urb = urb; + vurb->dir = usb_urb_dir_in(urb) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; + vurb->is_control = (usb_endpoint_num(&urb->ep->desc) == 0); + + spin_lock_irqsave(&q->urb_lock, flags); + status = usb_hcd_link_urb_to_ep(q->vhci->hcd, urb); + if (status) { + spin_unlock_irqrestore(&q->urb_lock, flags); + urb->hcpriv = NULL; + kfree(vurb); + return status; + } + + if (q->active) { + if (bce_vhci_transfer_queue_can_init_urb(vurb->q)) + status = bce_vhci_urb_init(vurb); + else + vurb->state = BCE_VHCI_URB_INIT_PENDING; + } else { + if (q->stalled) + bce_vhci_transfer_queue_request_reset(q); + vurb->state = BCE_VHCI_URB_INIT_PENDING; + } + if (status) { + usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); + urb->hcpriv = NULL; + kfree(vurb); + } else { + bce_vhci_transfer_queue_deliver_pending(q); + } + spin_unlock_irqrestore(&q->urb_lock, flags); + pr_debug("bce-vhci: [%02x] URB enqueued (dir = %s, size = %i)\n", q->endp_addr, + usb_urb_dir_in(urb) ? "IN" : "OUT", urb->transfer_buffer_length); + return status; +} + +static int bce_vhci_urb_init(struct bce_vhci_urb *vurb) +{ + int status = 0; + + if (vurb->q->remaining_active_requests == 0) { + pr_err("bce-vhci: cannot init request (remaining_active_requests = 0)\n"); + return -EINVAL; + } + + if (vurb->is_control) { + vurb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST; + } else { + status = bce_vhci_urb_data_start(vurb, NULL); + } + + if (!status) { + --vurb->q->remaining_active_requests; + } + return status; +} + +static void bce_vhci_urb_complete(struct bce_vhci_urb *urb, int status) +{ + struct bce_vhci_transfer_queue *q = urb->q; + struct bce_vhci *vhci = q->vhci; + struct urb *real_urb = urb->urb; + pr_debug("bce-vhci: [%02x] URB complete %i\n", q->endp_addr, status); + usb_hcd_unlink_urb_from_ep(vhci->hcd, real_urb); + real_urb->hcpriv = NULL; + real_urb->status = status; + if (urb->state != BCE_VHCI_URB_INIT_PENDING) + ++urb->q->remaining_active_requests; + kfree(urb); + list_add_tail(&real_urb->urb_list, &q->giveback_urb_list); +} + +int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status) +{ + struct bce_vhci_urb *vurb; + unsigned long flags; + int ret; + + spin_lock_irqsave(&q->urb_lock, flags); + if ((ret = usb_hcd_check_unlink_urb(q->vhci->hcd, urb, status))) { + spin_unlock_irqrestore(&q->urb_lock, flags); + return ret; + } + + vurb = urb->hcpriv; + /* If the URB wasn't posted to the device yet, we can still remove it on the host without pausing the queue. */ + if (vurb->state != BCE_VHCI_URB_INIT_PENDING) { + pr_debug("bce-vhci: [%02x] Cancelling URB\n", q->endp_addr); + + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_pause(q, BCE_VHCI_PAUSE_INTERNAL_WQ); + spin_lock_irqsave(&q->urb_lock, flags); + + ++q->remaining_active_requests; + } + + usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); + + spin_unlock_irqrestore(&q->urb_lock, flags); + + usb_hcd_giveback_urb(q->vhci->hcd, urb, status); + + if (vurb->state != BCE_VHCI_URB_INIT_PENDING) + bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); + + kfree(vurb); + + return 0; +} + +static int bce_vhci_urb_data_transfer_in(struct bce_vhci_urb *urb, unsigned long *timeout) +{ + struct bce_vhci_message msg; + struct bce_qe_submission *s; + u32 tr_len; + int reservation1, reservation2 = -EFAULT; + + pr_debug("bce-vhci: [%02x] DMA from device %llx %x\n", urb->q->endp_addr, + (u64) urb->urb->transfer_dma, urb->urb->transfer_buffer_length); + + /* Reserve both a message and a submission, so we don't run into issues later. */ + reservation1 = bce_reserve_submission(urb->q->vhci->msg_asynchronous.sq, timeout); + if (!reservation1) + reservation2 = bce_reserve_submission(urb->q->sq_in, timeout); + if (reservation1 || reservation2) { + pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); + if (!reservation1) + bce_cancel_submission_reservation(urb->q->vhci->msg_asynchronous.sq); + return -ENOMEM; + } + + urb->send_offset = urb->receive_offset; + + tr_len = urb->urb->transfer_buffer_length - urb->send_offset; + + spin_lock(&urb->q->vhci->msg_asynchronous_lock); + msg.cmd = BCE_VHCI_CMD_TRANSFER_REQUEST; + msg.status = 0; + msg.param1 = ((urb->urb->ep->desc.bEndpointAddress & 0x8Fu) << 8) | urb->q->dev_addr; + msg.param2 = tr_len; + bce_vhci_message_queue_write(&urb->q->vhci->msg_asynchronous, &msg); + spin_unlock(&urb->q->vhci->msg_asynchronous_lock); + + s = bce_next_submission(urb->q->sq_in); + bce_set_submission_single(s, urb->urb->transfer_dma + urb->send_offset, tr_len); + bce_submit_to_device(urb->q->sq_in); + + urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; + return 0; +} + +static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout) +{ + if (urb->dir == DMA_TO_DEVICE) { + if (urb->urb->transfer_buffer_length > 0) + urb->state = BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST; + else + urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; + return 0; + } else { + return bce_vhci_urb_data_transfer_in(urb, timeout); + } +} + +static int bce_vhci_urb_send_out_data(struct bce_vhci_urb *urb, dma_addr_t addr, size_t size) +{ + struct bce_qe_submission *s; + unsigned long timeout = 0; + if (bce_reserve_submission(urb->q->sq_out, &timeout)) { + pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); + return -EPIPE; + } + + pr_debug("bce-vhci: [%02x] DMA to device %llx %lx\n", urb->q->endp_addr, (u64) addr, size); + + s = bce_next_submission(urb->q->sq_out); + bce_set_submission_single(s, addr, size); + bce_submit_to_device(urb->q->sq_out); + return 0; +} + +static int bce_vhci_urb_data_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + u32 tr_len; + int status; + if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST) { + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { + tr_len = min(urb->urb->transfer_buffer_length - urb->send_offset, (u32) msg->param2); + if ((status = bce_vhci_urb_send_out_data(urb, urb->urb->transfer_dma + urb->send_offset, tr_len))) + return status; + urb->send_offset += tr_len; + urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; + return 0; + } + } + + /* 0x1000 in out queues aren't really unexpected */ + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) + return -EAGAIN; + pr_err("bce-vhci: [%02x] %s URB unexpected message (state = %x, msg: %x %x %x %llx)\n", + urb->q->endp_addr, (urb->is_control ? "Control (data update)" : "Data"), urb->state, + msg->cmd, msg->status, msg->param1, msg->param2); + return -EAGAIN; +} + +static int bce_vhci_urb_data_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + urb->receive_offset += c->data_size; + if (urb->dir == DMA_FROM_DEVICE || urb->receive_offset >= urb->urb->transfer_buffer_length) { + urb->urb->actual_length = (u32) urb->receive_offset; + urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; + if (!urb->is_control) { + bce_vhci_urb_complete(urb, 0); + return -ENOENT; + } + } + } else { + pr_err("bce-vhci: [%02x] Data URB unexpected completion\n", urb->q->endp_addr); + } + return 0; +} + + +static int bce_vhci_urb_control_check_status(struct bce_vhci_urb *urb) +{ + struct bce_vhci_transfer_queue *q = urb->q; + if (urb->received_status == 0) + return 0; + if (urb->state == BCE_VHCI_URB_DATA_TRANSFER_COMPLETE || + (urb->received_status != BCE_VHCI_SUCCESS && urb->state != BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST && + urb->state != BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION)) { + urb->state = BCE_VHCI_URB_CONTROL_COMPLETE; + if (urb->received_status != BCE_VHCI_SUCCESS) { + pr_err("bce-vhci: [%02x] URB failed: %x\n", urb->q->endp_addr, urb->received_status); + urb->q->active = false; + urb->q->stalled = true; + bce_vhci_urb_complete(urb, -EPIPE); + if (!list_empty(&q->endp->urb_list)) + bce_vhci_transfer_queue_request_reset(q); + return -ENOENT; + } + bce_vhci_urb_complete(urb, 0); + return -ENOENT; + } + return 0; +} + +static int bce_vhci_urb_control_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + int status; + if (msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { + urb->received_status = msg->status; + return bce_vhci_urb_control_check_status(urb); + } + + if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST) { + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { + if (bce_vhci_urb_send_out_data(urb, urb->urb->setup_dma, sizeof(struct usb_ctrlrequest))) { + pr_err("bce-vhci: [%02x] Failed to start URB setup transfer\n", urb->q->endp_addr); + return 0; /* TODO: fail the URB? */ + } + urb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION; + pr_debug("bce-vhci: [%02x] Sent setup %llx\n", urb->q->endp_addr, urb->urb->setup_dma); + return 0; + } + } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || + urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + if ((status = bce_vhci_urb_data_update(urb, msg))) + return status; + return bce_vhci_urb_control_check_status(urb); + } + + /* 0x1000 in out queues aren't really unexpected */ + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) + return -EAGAIN; + pr_err("bce-vhci: [%02x] Control URB unexpected message (state = %x, msg: %x %x %x %llx)\n", urb->q->endp_addr, + urb->state, msg->cmd, msg->status, msg->param1, msg->param2); + return -EAGAIN; +} + +static int bce_vhci_urb_control_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + int status; + unsigned long timeout; + + if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION) { + if (c->data_size != sizeof(struct usb_ctrlrequest)) + pr_err("bce-vhci: [%02x] transfer complete data size mistmatch for usb_ctrlrequest (%llx instead of %lx)\n", + urb->q->endp_addr, c->data_size, sizeof(struct usb_ctrlrequest)); + + timeout = 1000; + status = bce_vhci_urb_data_start(urb, &timeout); + if (status) { + bce_vhci_urb_complete(urb, status); + return -ENOENT; + } + return 0; + } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || + urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + if ((status = bce_vhci_urb_data_transfer_completion(urb, c))) + return status; + return bce_vhci_urb_control_check_status(urb); + } else { + pr_err("bce-vhci: [%02x] Control URB unexpected completion (state = %x)\n", urb->q->endp_addr, urb->state); + } + return 0; +} + +static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + if (urb->state == BCE_VHCI_URB_INIT_PENDING) + return -EAGAIN; + if (urb->is_control) + return bce_vhci_urb_control_update(urb, msg); + else + return bce_vhci_urb_data_update(urb, msg); +} + +static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + if (urb->is_control) + return bce_vhci_urb_control_transfer_completion(urb, c); + else + return bce_vhci_urb_data_transfer_completion(urb, c); +} + +static void bce_vhci_urb_resume(struct bce_vhci_urb *urb) +{ + int status = 0; + if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + status = bce_vhci_urb_data_transfer_in(urb, NULL); + } + if (status) + bce_vhci_urb_complete(urb, status); +} diff --git a/drivers/staging/apple-bce/vhci/transfer.h b/drivers/staging/apple-bce/vhci/transfer.h new file mode 100644 index 000000000000..89ecad6bcf8f --- /dev/null +++ b/drivers/staging/apple-bce/vhci/transfer.h @@ -0,0 +1,73 @@ +#ifndef BCEDRIVER_TRANSFER_H +#define BCEDRIVER_TRANSFER_H + +#include +#include "queue.h" +#include "command.h" +#include "../queue.h" + +struct bce_vhci_list_message { + struct list_head list; + struct bce_vhci_message msg; +}; +enum bce_vhci_pause_source { + BCE_VHCI_PAUSE_INTERNAL_WQ = 1, + BCE_VHCI_PAUSE_FIRMWARE = 2, + BCE_VHCI_PAUSE_SUSPEND = 4, + BCE_VHCI_PAUSE_SHUTDOWN = 8 +}; +struct bce_vhci_transfer_queue { + struct bce_vhci *vhci; + struct usb_host_endpoint *endp; + enum bce_vhci_endpoint_state state; + u32 max_active_requests, remaining_active_requests; + bool active, stalled; + u32 paused_by; + bce_vhci_device_t dev_addr; + u8 endp_addr; + struct bce_queue_cq *cq; + struct bce_queue_sq *sq_in; + struct bce_queue_sq *sq_out; + struct list_head evq; + struct spinlock urb_lock; + struct mutex pause_lock; + struct list_head giveback_urb_list; + + struct work_struct w_reset; +}; +enum bce_vhci_urb_state { + BCE_VHCI_URB_INIT_PENDING, + + BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST, + BCE_VHCI_URB_WAITING_FOR_COMPLETION, + BCE_VHCI_URB_DATA_TRANSFER_COMPLETE, + + BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST, + BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION, + BCE_VHCI_URB_CONTROL_COMPLETE +}; +struct bce_vhci_urb { + struct urb *urb; + struct bce_vhci_transfer_queue *q; + enum dma_data_direction dir; + bool is_control; + enum bce_vhci_urb_state state; + int received_status; + u32 send_offset; + u32 receive_offset; +}; + +void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, + struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir); +void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q); +void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg); +int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q); +int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q); +int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); +int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); +void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q); + +int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, struct urb *urb); +int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status); + +#endif //BCEDRIVER_TRANSFER_H diff --git a/drivers/staging/apple-bce/vhci/vhci.c b/drivers/staging/apple-bce/vhci/vhci.c new file mode 100644 index 000000000000..eb26f55000d8 --- /dev/null +++ b/drivers/staging/apple-bce/vhci/vhci.c @@ -0,0 +1,759 @@ +#include "vhci.h" +#include "../apple_bce.h" +#include "command.h" +#include +#include +#include +#include + +static dev_t bce_vhci_chrdev; +static struct class *bce_vhci_class; +static const struct hc_driver bce_vhci_driver; +static u16 bce_vhci_port_mask = U16_MAX; + +static int bce_vhci_create_event_queues(struct bce_vhci *vhci); +static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci); +static int bce_vhci_create_message_queues(struct bce_vhci *vhci); +static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci); +static void bce_vhci_handle_firmware_events_w(struct work_struct *ws); +static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq); + +int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci) +{ + int status; + + spin_lock_init(&vhci->hcd_spinlock); + + vhci->dev = dev; + + vhci->vdevt = bce_vhci_chrdev; + vhci->vdev = device_create(bce_vhci_class, dev->dev, vhci->vdevt, NULL, "bce-vhci"); + if (IS_ERR_OR_NULL(vhci->vdev)) { + status = PTR_ERR(vhci->vdev); + goto fail_dev; + } + + if ((status = bce_vhci_create_message_queues(vhci))) + goto fail_mq; + if ((status = bce_vhci_create_event_queues(vhci))) + goto fail_eq; + + vhci->tq_state_wq = alloc_ordered_workqueue("bce-vhci-tq-state", 0); + INIT_WORK(&vhci->w_fw_events, bce_vhci_handle_firmware_events_w); + + vhci->hcd = usb_create_hcd(&bce_vhci_driver, vhci->vdev, "bce-vhci"); + if (!vhci->hcd) { + status = -ENOMEM; + goto fail_hcd; + } + vhci->hcd->self.sysdev = &dev->pci->dev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(5,4,0) + vhci->hcd->self.uses_dma = 1; +#endif + *((struct bce_vhci **) vhci->hcd->hcd_priv) = vhci; + vhci->hcd->speed = HCD_USB2; + + if ((status = usb_add_hcd(vhci->hcd, 0, 0))) + goto fail_hcd; + + return 0; + +fail_hcd: + bce_vhci_destroy_event_queues(vhci); +fail_eq: + bce_vhci_destroy_message_queues(vhci); +fail_mq: + device_destroy(bce_vhci_class, vhci->vdevt); +fail_dev: + if (!status) + status = -EINVAL; + return status; +} + +void bce_vhci_destroy(struct bce_vhci *vhci) +{ + usb_remove_hcd(vhci->hcd); + bce_vhci_destroy_event_queues(vhci); + bce_vhci_destroy_message_queues(vhci); + device_destroy(bce_vhci_class, vhci->vdevt); +} + +struct bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd) +{ + return *((struct bce_vhci **) hcd->hcd_priv); +} + +int bce_vhci_start(struct usb_hcd *hcd) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int status; + u16 port_mask = 0; + bce_vhci_port_t port_no = 0; + if ((status = bce_vhci_cmd_controller_enable(&vhci->cq, 1, &port_mask))) + return status; + vhci->port_mask = port_mask; + vhci->port_power_mask = 0; + if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) + return status; + port_mask = vhci->port_mask; + while (port_mask) { + port_no += 1; + port_mask >>= 1; + } + vhci->port_count = port_no; + return 0; +} + +void bce_vhci_stop(struct usb_hcd *hcd) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_cmd_controller_disable(&vhci->cq); +} + +static int bce_vhci_hub_status_data(struct usb_hcd *hcd, char *buf) +{ + return 0; +} + +static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout); + +static int bce_vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue, u16 wIndex, char *buf, u16 wLength) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int status; + struct usb_hub_descriptor *hd; + struct usb_hub_status *hs; + struct usb_port_status *ps; + u32 port_status; + // pr_info("bce-vhci: bce_vhci_hub_control %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); + if (typeReq == GetHubDescriptor && wLength >= sizeof(struct usb_hub_descriptor)) { + hd = (struct usb_hub_descriptor *) buf; + memset(hd, 0, sizeof(*hd)); + hd->bDescLength = sizeof(struct usb_hub_descriptor); + hd->bDescriptorType = USB_DT_HUB; + hd->bNbrPorts = (u8) vhci->port_count; + hd->wHubCharacteristics = HUB_CHAR_INDV_PORT_LPSM | HUB_CHAR_INDV_PORT_OCPM; + hd->bPwrOn2PwrGood = 0; + hd->bHubContrCurrent = 0; + return 0; + } else if (typeReq == GetHubStatus && wLength >= sizeof(struct usb_hub_status)) { + hs = (struct usb_hub_status *) buf; + memset(hs, 0, sizeof(*hs)); + hs->wHubStatus = 0; + hs->wHubChange = 0; + return 0; + } else if (typeReq == GetPortStatus && wLength >= 4 /* usb 2.0 */) { + ps = (struct usb_port_status *) buf; + ps->wPortStatus = 0; + ps->wPortChange = 0; + + if (vhci->port_power_mask & BIT(wIndex)) + ps->wPortStatus |= USB_PORT_STAT_POWER; + + if (!(bce_vhci_port_mask & BIT(wIndex))) + return 0; + + if ((status = bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0, &port_status))) + return status; + + if (port_status & 16) + ps->wPortStatus |= USB_PORT_STAT_ENABLE | USB_PORT_STAT_HIGH_SPEED; + if (port_status & 4) + ps->wPortStatus |= USB_PORT_STAT_CONNECTION; + if (port_status & 2) + ps->wPortStatus |= USB_PORT_STAT_OVERCURRENT; + if (port_status & 8) + ps->wPortStatus |= USB_PORT_STAT_RESET; + if (port_status & 0x60) + ps->wPortStatus |= USB_PORT_STAT_SUSPEND; + + if (port_status & 0x40000) + ps->wPortChange |= USB_PORT_STAT_C_CONNECTION; + + pr_debug("bce-vhci: Translated status %x to %x:%x\n", port_status, ps->wPortStatus, ps->wPortChange); + return 0; + } else if (typeReq == SetPortFeature) { + if (wValue == USB_PORT_FEAT_POWER) { + status = bce_vhci_cmd_port_power_on(&vhci->cq, (u8) wIndex); + /* As far as I am aware, power status is not part of the port status so store it separately */ + if (!status) + vhci->port_power_mask |= BIT(wIndex); + return status; + } + if (wValue == USB_PORT_FEAT_RESET) { + return bce_vhci_reset_device(vhci, wIndex, wValue); + } + if (wValue == USB_PORT_FEAT_SUSPEND) { + /* TODO: Am I supposed to also suspend the endpoints? */ + pr_debug("bce-vhci: Suspending port %i\n", wIndex); + return bce_vhci_cmd_port_suspend(&vhci->cq, (u8) wIndex); + } + } else if (typeReq == ClearPortFeature) { + if (wValue == USB_PORT_FEAT_ENABLE) + return bce_vhci_cmd_port_disable(&vhci->cq, (u8) wIndex); + if (wValue == USB_PORT_FEAT_POWER) { + status = bce_vhci_cmd_port_power_off(&vhci->cq, (u8) wIndex); + if (!status) + vhci->port_power_mask &= ~BIT(wIndex); + return status; + } + if (wValue == USB_PORT_FEAT_C_CONNECTION) + return bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0x40000, &port_status); + if (wValue == USB_PORT_FEAT_C_RESET) { /* I don't think I can transfer it in any way */ + return 0; + } + if (wValue == USB_PORT_FEAT_SUSPEND) { + pr_debug("bce-vhci: Resuming port %i\n", wIndex); + return bce_vhci_cmd_port_resume(&vhci->cq, (u8) wIndex); + } + } + pr_err("bce-vhci: bce_vhci_hub_control unhandled request: %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); + dump_stack(); + return -EIO; +} + +static int bce_vhci_enable_device(struct usb_hcd *hcd, struct usb_device *udev) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + struct bce_vhci_device *vdev; + bce_vhci_device_t devid; + pr_info("bce_vhci_enable_device\n"); + + if (vhci->port_to_device[udev->portnum]) + return 0; + + /* We need to early address the device */ + if (bce_vhci_cmd_device_create(&vhci->cq, udev->portnum, &devid)) + return -EIO; + + pr_info("bce_vhci_cmd_device_create %i -> %i\n", udev->portnum, devid); + + vdev = kzalloc(sizeof(struct bce_vhci_device), GFP_KERNEL); + vhci->port_to_device[udev->portnum] = devid; + vhci->devices[devid] = vdev; + + bce_vhci_create_transfer_queue(vhci, &vdev->tq[0], &udev->ep0, devid, DMA_BIDIRECTIONAL); + udev->ep0.hcpriv = &vdev->tq[0]; + vdev->tq_mask |= BIT(0); + + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &udev->ep0.desc); + return 0; +} + +static int bce_vhci_address_device(struct usb_hcd *hcd, struct usb_device *udev, unsigned int timeout_ms) //TODO: follow timeout +{ + /* This is the same as enable_device, but instead in the old scheme */ + return bce_vhci_enable_device(hcd, udev); +} + +static void bce_vhci_free_device(struct usb_hcd *hcd, struct usb_device *udev) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int i; + bce_vhci_device_t devid; + struct bce_vhci_device *dev; + pr_info("bce_vhci_free_device %i\n", udev->portnum); + if (!vhci->port_to_device[udev->portnum]) + return; + devid = vhci->port_to_device[udev->portnum]; + dev = vhci->devices[devid]; + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); + bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); + } + } + vhci->devices[devid] = NULL; + vhci->port_to_device[udev->portnum] = 0; + bce_vhci_cmd_device_destroy(&vhci->cq, devid); + kfree(dev); +} + +static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout) +{ + struct bce_vhci_device *dev = NULL; + bce_vhci_device_t devid; + int i; + int status; + enum dma_data_direction dir; + pr_info("bce_vhci_reset_device %i\n", index); + + devid = vhci->port_to_device[index]; + if (devid) { + dev = vhci->devices[devid]; + + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); + bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); + } + } + vhci->devices[devid] = NULL; + vhci->port_to_device[index] = 0; + bce_vhci_cmd_device_destroy(&vhci->cq, devid); + } + status = bce_vhci_cmd_port_reset(&vhci->cq, (u8) index, timeout); + + if (dev) { + if ((status = bce_vhci_cmd_device_create(&vhci->cq, index, &devid))) + return status; + vhci->devices[devid] = dev; + vhci->port_to_device[index] = devid; + + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + dir = usb_endpoint_dir_in(&dev->tq[i].endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; + if (i == 0) + dir = DMA_BIDIRECTIONAL; + bce_vhci_create_transfer_queue(vhci, &dev->tq[i], dev->tq[i].endp, devid, dir); + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &dev->tq[i].endp->desc); + } + } + } + + return status; +} + +static int bce_vhci_check_bandwidth(struct usb_hcd *hcd, struct usb_device *udev) +{ + return 0; +} + +static int bce_vhci_get_frame_number(struct usb_hcd *hcd) +{ + return 0; +} + +static int bce_vhci_bus_suspend(struct usb_hcd *hcd) +{ + int i, j; + int status; + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + pr_info("bce_vhci: suspend started\n"); + + pr_info("bce_vhci: suspend endpoints\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + for (j = 0; j < 32; j++) { + if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) + continue; + bce_vhci_transfer_queue_pause(&vhci->devices[vhci->port_to_device[i]]->tq[j], + BCE_VHCI_PAUSE_SUSPEND); + } + } + + pr_info("bce_vhci: suspend ports\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + bce_vhci_cmd_port_suspend(&vhci->cq, i); + } + pr_info("bce_vhci: suspend controller\n"); + if ((status = bce_vhci_cmd_controller_pause(&vhci->cq))) + return status; + + bce_vhci_event_queue_pause(&vhci->ev_commands); + bce_vhci_event_queue_pause(&vhci->ev_system); + bce_vhci_event_queue_pause(&vhci->ev_isochronous); + bce_vhci_event_queue_pause(&vhci->ev_interrupt); + bce_vhci_event_queue_pause(&vhci->ev_asynchronous); + pr_info("bce_vhci: suspend done\n"); + return 0; +} + +static int bce_vhci_bus_resume(struct usb_hcd *hcd) +{ + int i, j; + int status; + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + pr_info("bce_vhci: resume started\n"); + + bce_vhci_event_queue_resume(&vhci->ev_system); + bce_vhci_event_queue_resume(&vhci->ev_isochronous); + bce_vhci_event_queue_resume(&vhci->ev_interrupt); + bce_vhci_event_queue_resume(&vhci->ev_asynchronous); + bce_vhci_event_queue_resume(&vhci->ev_commands); + + pr_info("bce_vhci: resume controller\n"); + if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) + return status; + + pr_info("bce_vhci: resume ports\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + bce_vhci_cmd_port_resume(&vhci->cq, i); + } + pr_info("bce_vhci: resume endpoints\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + for (j = 0; j < 32; j++) { + if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) + continue; + bce_vhci_transfer_queue_resume(&vhci->devices[vhci->port_to_device[i]]->tq[j], + BCE_VHCI_PAUSE_SUSPEND); + } + } + + pr_info("bce_vhci: resume done\n"); + return 0; +} + +static int bce_vhci_urb_enqueue(struct usb_hcd *hcd, struct urb *urb, gfp_t mem_flags) +{ + struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; + pr_debug("bce_vhci_urb_enqueue %i:%x\n", q->dev_addr, urb->ep->desc.bEndpointAddress); + if (!q) + return -ENOENT; + return bce_vhci_urb_create(q, urb); +} + +static int bce_vhci_urb_dequeue(struct usb_hcd *hcd, struct urb *urb, int status) +{ + struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; + pr_debug("bce_vhci_urb_dequeue %x\n", urb->ep->desc.bEndpointAddress); + return bce_vhci_urb_request_cancel(q, urb, status); +} + +static void bce_vhci_endpoint_reset(struct usb_hcd *hcd, struct usb_host_endpoint *ep) +{ + struct bce_vhci_transfer_queue *q = ep->hcpriv; + pr_debug("bce_vhci_endpoint_reset\n"); + if (q) + bce_vhci_transfer_queue_request_reset(q); +} + +static u8 bce_vhci_endpoint_index(u8 addr) +{ + if (addr & 0x80) + return (u8) (0x10 + (addr & 0xf)); + return (u8) (addr & 0xf); +} + +static int bce_vhci_add_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct usb_host_endpoint *endp) +{ + u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; + struct bce_vhci_device *vdev = vhci->devices[devid]; + pr_debug("bce_vhci_add_endpoint %x/%x:%x\n", udev->portnum, devid, endp_index); + + if (udev->bus->root_hub == udev) /* The USB hub */ + return 0; + if (vdev == NULL) + return -ENODEV; + if (vdev->tq_mask & BIT(endp_index)) { + endp->hcpriv = &vdev->tq[endp_index]; + return 0; + } + + bce_vhci_create_transfer_queue(vhci, &vdev->tq[endp_index], endp, devid, + usb_endpoint_dir_in(&endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE); + endp->hcpriv = &vdev->tq[endp_index]; + vdev->tq_mask |= BIT(endp_index); + + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &endp->desc); + return 0; +} + +static int bce_vhci_drop_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct usb_host_endpoint *endp) +{ + u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; + struct bce_vhci_transfer_queue *q = endp->hcpriv; + struct bce_vhci_device *vdev = vhci->devices[devid]; + pr_info("bce_vhci_drop_endpoint %x:%x\n", udev->portnum, endp_index); + if (!q) { + if (vdev && vdev->tq_mask & BIT(endp_index)) { + pr_err("something deleted the hcpriv?\n"); + q = &vdev->tq[endp_index]; + } else { + return 0; + } + } + + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) (endp->desc.bEndpointAddress & 0x8Fu)); + vhci->devices[devid]->tq_mask &= ~BIT(endp_index); + bce_vhci_destroy_transfer_queue(vhci, q); + return 0; +} + +static int bce_vhci_create_message_queues(struct bce_vhci *vhci) +{ + if (bce_vhci_message_queue_create(vhci, &vhci->msg_commands, "VHC1HostCommands") || + bce_vhci_message_queue_create(vhci, &vhci->msg_system, "VHC1HostSystemEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_isochronous, "VHC1HostIsochronousEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_interrupt, "VHC1HostInterruptEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_asynchronous, "VHC1HostAsynchronousEvents")) { + bce_vhci_destroy_message_queues(vhci); + return -EINVAL; + } + spin_lock_init(&vhci->msg_asynchronous_lock); + bce_vhci_command_queue_create(&vhci->cq, &vhci->msg_commands); + return 0; +} + +static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci) +{ + bce_vhci_command_queue_destroy(&vhci->cq); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_commands); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_system); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_isochronous); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_interrupt); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_asynchronous); +} + +static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); +static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); + +static int bce_vhci_create_event_queues(struct bce_vhci *vhci) +{ + vhci->ev_cq = bce_create_cq(vhci->dev, 0x100); + if (!vhci->ev_cq) + return -EINVAL; +#define CREATE_EVENT_QUEUE(field, name, cb) bce_vhci_event_queue_create(vhci, &vhci->field, name, cb) + if (__bce_vhci_event_queue_create(vhci, &vhci->ev_commands, "VHC1FirmwareCommands", + bce_vhci_firmware_event_completion) || + CREATE_EVENT_QUEUE(ev_system, "VHC1FirmwareSystemEvents", bce_vhci_handle_system_event) || + CREATE_EVENT_QUEUE(ev_isochronous, "VHC1FirmwareIsochronousEvents", bce_vhci_handle_usb_event) || + CREATE_EVENT_QUEUE(ev_interrupt, "VHC1FirmwareInterruptEvents", bce_vhci_handle_usb_event) || + CREATE_EVENT_QUEUE(ev_asynchronous, "VHC1FirmwareAsynchronousEvents", bce_vhci_handle_usb_event)) { + bce_vhci_destroy_event_queues(vhci); + return -EINVAL; + } +#undef CREATE_EVENT_QUEUE + return 0; +} + +static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci) +{ + bce_vhci_event_queue_destroy(vhci, &vhci->ev_commands); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_system); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_isochronous); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_interrupt); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_asynchronous); + if (vhci->ev_cq) + bce_destroy_cq(vhci->dev, vhci->ev_cq); +} + +static void bce_vhci_send_fw_event_response(struct bce_vhci *vhci, struct bce_vhci_message *req, u16 status) +{ + unsigned long timeout = 1000; + struct bce_vhci_message r = *req; + r.cmd = (u16) (req->cmd | 0x8000u); + r.status = status; + r.param1 = req->param1; + r.param2 = 0; + + if (bce_reserve_submission(vhci->msg_system.sq, &timeout)) { + pr_err("bce-vhci: Cannot reserve submision for FW event reply\n"); + return; + } + bce_vhci_message_queue_write(&vhci->msg_system, &r); +} + +static int bce_vhci_handle_firmware_event(struct bce_vhci *vhci, struct bce_vhci_message *msg) +{ + unsigned long flags; + bce_vhci_device_t devid; + u8 endp; + struct bce_vhci_device *dev; + struct bce_vhci_transfer_queue *tq; + if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE || msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { + devid = (bce_vhci_device_t) (msg->param1 & 0xff); + endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); + dev = vhci->devices[devid]; + if (!dev || !(dev->tq_mask & BIT(endp))) + return BCE_VHCI_BAD_ARGUMENT; + tq = &dev->tq[endp]; + } + + if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE) { + if (msg->param2 == BCE_VHCI_ENDPOINT_ACTIVE) { + bce_vhci_transfer_queue_resume(tq, BCE_VHCI_PAUSE_FIRMWARE); + return BCE_VHCI_SUCCESS; + } else if (msg->param2 == BCE_VHCI_ENDPOINT_PAUSED) { + bce_vhci_transfer_queue_pause(tq, BCE_VHCI_PAUSE_FIRMWARE); + return BCE_VHCI_SUCCESS; + } + return BCE_VHCI_BAD_ARGUMENT; + } else if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { + if (msg->param2 == BCE_VHCI_ENDPOINT_STALLED) { + tq->state = msg->param2; + spin_lock_irqsave(&tq->urb_lock, flags); + tq->stalled = true; + spin_unlock_irqrestore(&tq->urb_lock, flags); + return BCE_VHCI_SUCCESS; + } + return BCE_VHCI_BAD_ARGUMENT; + } + pr_warn("bce-vhci: Unhandled firmware event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + return BCE_VHCI_BAD_ARGUMENT; +} + +static void bce_vhci_handle_firmware_events_w(struct work_struct *ws) +{ + size_t cnt = 0; + int result; + struct bce_vhci *vhci = container_of(ws, struct bce_vhci, w_fw_events); + struct bce_queue_sq *sq = vhci->ev_commands.sq; + struct bce_sq_completion_data *cq; + struct bce_vhci_message *msg, *msg2 = NULL; + + while (true) { + if (msg2) { + msg = msg2; + msg2 = NULL; + } else if ((cq = bce_next_completion(sq))) { + if (cq->status == BCE_COMPLETION_ABORTED) { + bce_notify_submission_complete(sq); + continue; + } + msg = &vhci->ev_commands.data[sq->head]; + } else { + break; + } + + pr_debug("bce-vhci: Got fw event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); + if ((cq = bce_next_completion(sq))) { + msg2 = &vhci->ev_commands.data[(sq->head + 1) % sq->el_count]; + pr_debug("bce-vhci: Got second fw event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + if (cq->status != BCE_COMPLETION_ABORTED && + msg2->cmd == (msg->cmd | 0x4000) && msg2->param1 == msg->param1) { + /* Take two elements */ + pr_debug("bce-vhci: Cancelled\n"); + bce_vhci_send_fw_event_response(vhci, msg, BCE_VHCI_ABORT); + + bce_notify_submission_complete(sq); + bce_notify_submission_complete(sq); + msg2 = NULL; + cnt += 2; + continue; + } + + pr_warn("bce-vhci: Handle fw event - unexpected cancellation\n"); + } + + result = bce_vhci_handle_firmware_event(vhci, msg); + bce_vhci_send_fw_event_response(vhci, msg, (u16) result); + + + bce_notify_submission_complete(sq); + ++cnt; + } + bce_vhci_event_queue_submit_pending(&vhci->ev_commands, cnt); + if (atomic_read(&sq->available_commands) == sq->el_count - 1) { + pr_debug("bce-vhci: complete\n"); + complete(&vhci->ev_commands.queue_empty_completion); + } +} + +static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq) +{ + struct bce_vhci_event_queue *q = sq->userdata; + queue_work(q->vhci->tq_state_wq, &q->vhci->w_fw_events); +} + +static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) +{ + if (msg->cmd & 0x8000) { + bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); + } else { + pr_warn("bce-vhci: Unhandled system event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + } +} + +static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) +{ + bce_vhci_device_t devid; + u8 endp; + struct bce_vhci_device *dev; + if (msg->cmd & 0x8000) { + bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); + } else if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST || msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { + devid = (bce_vhci_device_t) (msg->param1 & 0xff); + endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); + dev = q->vhci->devices[devid]; + if (!dev || (dev->tq_mask & BIT(endp)) == 0) { + pr_err("bce-vhci: Didn't find destination for transfer queue event\n"); + return; + } + bce_vhci_transfer_queue_event(&dev->tq[endp], msg); + } else { + pr_warn("bce-vhci: Unhandled USB event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + } +} + + + +static const struct hc_driver bce_vhci_driver = { + .description = "bce-vhci", + .product_desc = "BCE VHCI Host Controller", + .hcd_priv_size = sizeof(struct bce_vhci *), + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5,4,0) + .flags = HCD_USB2, +#else + .flags = HCD_USB2 | HCD_DMA, +#endif + + .start = bce_vhci_start, + .stop = bce_vhci_stop, + .hub_status_data = bce_vhci_hub_status_data, + .hub_control = bce_vhci_hub_control, + .urb_enqueue = bce_vhci_urb_enqueue, + .urb_dequeue = bce_vhci_urb_dequeue, + .enable_device = bce_vhci_enable_device, + .free_dev = bce_vhci_free_device, + .address_device = bce_vhci_address_device, + .add_endpoint = bce_vhci_add_endpoint, + .drop_endpoint = bce_vhci_drop_endpoint, + .endpoint_reset = bce_vhci_endpoint_reset, + .check_bandwidth = bce_vhci_check_bandwidth, + .get_frame_number = bce_vhci_get_frame_number, + .bus_suspend = bce_vhci_bus_suspend, + .bus_resume = bce_vhci_bus_resume +}; + + +int __init bce_vhci_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&bce_vhci_chrdev, 0, 1, "bce-vhci"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + bce_vhci_class = class_create(THIS_MODULE, "bce-vhci"); +#else + bce_vhci_class = class_create("bce-vhci"); +#endif + if (IS_ERR(bce_vhci_class)) { + result = PTR_ERR(bce_vhci_class); + goto fail_class; + } + return 0; + +fail_class: + class_destroy(bce_vhci_class); +fail_chrdev: + unregister_chrdev_region(bce_vhci_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} +void __exit bce_vhci_module_exit(void) +{ + class_destroy(bce_vhci_class); + unregister_chrdev_region(bce_vhci_chrdev, 1); +} + +module_param_named(vhci_port_mask, bce_vhci_port_mask, ushort, 0444); +MODULE_PARM_DESC(vhci_port_mask, "Specifies which VHCI ports are enabled"); diff --git a/drivers/staging/apple-bce/vhci/vhci.h b/drivers/staging/apple-bce/vhci/vhci.h new file mode 100644 index 000000000000..6c2e22622f4c --- /dev/null +++ b/drivers/staging/apple-bce/vhci/vhci.h @@ -0,0 +1,52 @@ +#ifndef BCE_VHCI_H +#define BCE_VHCI_H + +#include "queue.h" +#include "transfer.h" + +struct usb_hcd; +struct bce_queue_cq; + +struct bce_vhci_device { + struct bce_vhci_transfer_queue tq[32]; + u32 tq_mask; +}; +struct bce_vhci { + struct apple_bce_device *dev; + dev_t vdevt; + struct device *vdev; + struct usb_hcd *hcd; + struct spinlock hcd_spinlock; + struct bce_vhci_message_queue msg_commands; + struct bce_vhci_message_queue msg_system; + struct bce_vhci_message_queue msg_isochronous; + struct bce_vhci_message_queue msg_interrupt; + struct bce_vhci_message_queue msg_asynchronous; + struct spinlock msg_asynchronous_lock; + struct bce_vhci_command_queue cq; + struct bce_queue_cq *ev_cq; + struct bce_vhci_event_queue ev_commands; + struct bce_vhci_event_queue ev_system; + struct bce_vhci_event_queue ev_isochronous; + struct bce_vhci_event_queue ev_interrupt; + struct bce_vhci_event_queue ev_asynchronous; + u16 port_mask; + u8 port_count; + u16 port_power_mask; + bce_vhci_device_t port_to_device[16]; + struct bce_vhci_device *devices[16]; + struct workqueue_struct *tq_state_wq; + struct work_struct w_fw_events; +}; + +int __init bce_vhci_module_init(void); +void __exit bce_vhci_module_exit(void); + +int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci); +void bce_vhci_destroy(struct bce_vhci *vhci); +int bce_vhci_start(struct usb_hcd *hcd); +void bce_vhci_stop(struct usb_hcd *hcd); + +struct bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd); + +#endif //BCE_VHCI_H diff --git a/include/drm/drm_format_helper.h b/include/drm/drm_format_helper.h index 428d81afe215..aa1604d92c1a 100644 --- a/include/drm/drm_format_helper.h +++ b/include/drm/drm_format_helper.h @@ -96,6 +96,9 @@ void drm_fb_xrgb8888_to_rgba5551(struct iosys_map *dst, const unsigned int *dst_ void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pitch, const struct iosys_map *src, const struct drm_framebuffer *fb, const struct drm_rect *clip, struct drm_format_conv_state *state); +void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + const struct drm_rect *clip, struct drm_format_conv_state *state); void drm_fb_xrgb8888_to_argb8888(struct iosys_map *dst, const unsigned int *dst_pitch, const struct iosys_map *src, const struct drm_framebuffer *fb, const struct drm_rect *clip, struct drm_format_conv_state *state); diff --git a/include/linux/hid.h b/include/linux/hid.h index 9ca7e26ac4e9..593b21bd64ff 100644 --- a/include/linux/hid.h +++ b/include/linux/hid.h @@ -590,7 +590,9 @@ struct hid_input { enum hid_type { HID_TYPE_OTHER = 0, HID_TYPE_USBMOUSE, - HID_TYPE_USBNONE + HID_TYPE_USBNONE, + HID_TYPE_SPI_KEYBOARD, + HID_TYPE_SPI_MOUSE, }; enum hid_battery_status { @@ -750,6 +752,8 @@ struct hid_descriptor { .bus = BUS_BLUETOOTH, .vendor = (ven), .product = (prod) #define HID_I2C_DEVICE(ven, prod) \ .bus = BUS_I2C, .vendor = (ven), .product = (prod) +#define HID_SPI_DEVICE(ven, prod) \ + .bus = BUS_SPI, .vendor = (ven), .product = (prod) #define HID_REPORT_ID(rep) \ .report_type = (rep) diff --git a/include/linux/soc/apple/dockchannel.h b/include/linux/soc/apple/dockchannel.h new file mode 100644 index 000000000000..0b7093935ddf --- /dev/null +++ b/include/linux/soc/apple/dockchannel.h @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: GPL-2.0-only OR MIT */ +/* + * Apple Dockchannel devices + * Copyright (C) The Asahi Linux Contributors + */ +#ifndef _LINUX_APPLE_DOCKCHANNEL_H_ +#define _LINUX_APPLE_DOCKCHANNEL_H_ + +#include +#include +#include + +#if IS_ENABLED(CONFIG_APPLE_DOCKCHANNEL) + +struct dockchannel; + +struct dockchannel *dockchannel_init(struct platform_device *pdev); + +int dockchannel_send(struct dockchannel *dockchannel, const void *buf, size_t count); +int dockchannel_recv(struct dockchannel *dockchannel, void *buf, size_t count); +int dockchannel_await(struct dockchannel *dockchannel, + void (*callback)(void *cookie, size_t avail), + void *cookie, size_t count); + +#endif +#endif diff --git a/include/linux/soc/apple/rtkit.h b/include/linux/soc/apple/rtkit.h index c06d17599ae7..736f53018017 100644 --- a/include/linux/soc/apple/rtkit.h +++ b/include/linux/soc/apple/rtkit.h @@ -56,7 +56,7 @@ struct apple_rtkit_shmem { * context. */ struct apple_rtkit_ops { - void (*crashed)(void *cookie); + void (*crashed)(void *cookie, const void *crashlog, size_t crashlog_size); void (*recv_message)(void *cookie, u8 endpoint, u64 message); bool (*recv_message_early)(void *cookie, u8 endpoint, u64 message); int (*shmem_setup)(void *cookie, struct apple_rtkit_shmem *bfr); -- 2.49.0.391.g4bbb303af6 From a9a81256578740771f346d17be439ace9250643a Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Sun, 20 Apr 2025 15:34:05 +0200 Subject: [PATCH 9/9] zstd Signed-off-by: Peter Jung --- MAINTAINERS | 1 + include/linux/zstd.h | 87 +- include/linux/zstd_errors.h | 30 +- include/linux/zstd_lib.h | 1123 ++++-- lib/zstd/Makefile | 3 +- lib/zstd/common/allocations.h | 56 + lib/zstd/common/bits.h | 150 + lib/zstd/common/bitstream.h | 155 +- lib/zstd/common/compiler.h | 151 +- lib/zstd/common/cpu.h | 3 +- lib/zstd/common/debug.c | 9 +- lib/zstd/common/debug.h | 37 +- lib/zstd/common/entropy_common.c | 42 +- lib/zstd/common/error_private.c | 13 +- lib/zstd/common/error_private.h | 88 +- lib/zstd/common/fse.h | 103 +- lib/zstd/common/fse_decompress.c | 132 +- lib/zstd/common/huf.h | 240 +- lib/zstd/common/mem.h | 3 +- lib/zstd/common/portability_macros.h | 51 +- lib/zstd/common/zstd_common.c | 38 +- lib/zstd/common/zstd_deps.h | 16 +- lib/zstd/common/zstd_internal.h | 153 +- lib/zstd/compress/clevels.h | 3 +- lib/zstd/compress/fse_compress.c | 74 +- lib/zstd/compress/hist.c | 13 +- lib/zstd/compress/hist.h | 10 +- lib/zstd/compress/huf_compress.c | 441 ++- lib/zstd/compress/zstd_compress.c | 3293 ++++++++++++----- lib/zstd/compress/zstd_compress_internal.h | 621 +++- lib/zstd/compress/zstd_compress_literals.c | 157 +- lib/zstd/compress/zstd_compress_literals.h | 25 +- lib/zstd/compress/zstd_compress_sequences.c | 21 +- lib/zstd/compress/zstd_compress_sequences.h | 16 +- lib/zstd/compress/zstd_compress_superblock.c | 394 +- lib/zstd/compress/zstd_compress_superblock.h | 3 +- lib/zstd/compress/zstd_cwksp.h | 222 +- lib/zstd/compress/zstd_double_fast.c | 245 +- lib/zstd/compress/zstd_double_fast.h | 27 +- lib/zstd/compress/zstd_fast.c | 703 +++- lib/zstd/compress/zstd_fast.h | 16 +- lib/zstd/compress/zstd_lazy.c | 840 +++-- lib/zstd/compress/zstd_lazy.h | 195 +- lib/zstd/compress/zstd_ldm.c | 102 +- lib/zstd/compress/zstd_ldm.h | 17 +- lib/zstd/compress/zstd_ldm_geartab.h | 3 +- lib/zstd/compress/zstd_opt.c | 571 +-- lib/zstd/compress/zstd_opt.h | 55 +- lib/zstd/compress/zstd_preSplit.c | 239 ++ lib/zstd/compress/zstd_preSplit.h | 34 + lib/zstd/decompress/huf_decompress.c | 887 +++-- lib/zstd/decompress/zstd_ddict.c | 9 +- lib/zstd/decompress/zstd_ddict.h | 3 +- lib/zstd/decompress/zstd_decompress.c | 377 +- lib/zstd/decompress/zstd_decompress_block.c | 724 ++-- lib/zstd/decompress/zstd_decompress_block.h | 10 +- .../decompress/zstd_decompress_internal.h | 19 +- lib/zstd/decompress_sources.h | 2 +- lib/zstd/zstd_common_module.c | 5 +- lib/zstd/zstd_compress_module.c | 75 +- lib/zstd/zstd_decompress_module.c | 4 +- 61 files changed, 8755 insertions(+), 4384 deletions(-) create mode 100644 lib/zstd/common/allocations.h create mode 100644 lib/zstd/common/bits.h create mode 100644 lib/zstd/compress/zstd_preSplit.c create mode 100644 lib/zstd/compress/zstd_preSplit.h diff --git a/MAINTAINERS b/MAINTAINERS index 2b1f3e8bdbdd..11b706953286 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -26310,6 +26310,7 @@ F: mm/zsmalloc.c ZSTD M: Nick Terrell +M: David Sterba S: Maintained B: https://github.com/facebook/zstd/issues T: git https://github.com/terrelln/linux.git diff --git a/include/linux/zstd.h b/include/linux/zstd.h index b2c7cf310c8f..2f2a3c8b8a33 100644 --- a/include/linux/zstd.h +++ b/include/linux/zstd.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -160,7 +160,6 @@ typedef ZSTD_parameters zstd_parameters; zstd_parameters zstd_get_params(int level, unsigned long long estimated_src_size); - /** * zstd_get_cparams() - returns zstd_compression_parameters for selected level * @level: The compression level @@ -173,9 +172,20 @@ zstd_parameters zstd_get_params(int level, zstd_compression_parameters zstd_get_cparams(int level, unsigned long long estimated_src_size, size_t dict_size); -/* ====== Single-pass Compression ====== */ - typedef ZSTD_CCtx zstd_cctx; +typedef ZSTD_cParameter zstd_cparameter; + +/** + * zstd_cctx_set_param() - sets a compression parameter + * @cctx: The context. Must have been initialized with zstd_init_cctx(). + * @param: The parameter to set. + * @value: The value to set the parameter to. + * + * Return: Zero or an error, which can be checked using zstd_is_error(). + */ +size_t zstd_cctx_set_param(zstd_cctx *cctx, zstd_cparameter param, int value); + +/* ====== Single-pass Compression ====== */ /** * zstd_cctx_workspace_bound() - max memory needed to initialize a zstd_cctx @@ -190,6 +200,20 @@ typedef ZSTD_CCtx zstd_cctx; */ size_t zstd_cctx_workspace_bound(const zstd_compression_parameters *parameters); +/** + * zstd_cctx_workspace_bound_with_ext_seq_prod() - max memory needed to + * initialize a zstd_cctx when using the block-level external sequence + * producer API. + * @parameters: The compression parameters to be used. + * + * If multiple compression parameters might be used, the caller must call + * this function for each set of parameters and use the maximum size. + * + * Return: A lower bound on the size of the workspace that is passed to + * zstd_init_cctx(). + */ +size_t zstd_cctx_workspace_bound_with_ext_seq_prod(const zstd_compression_parameters *parameters); + /** * zstd_init_cctx() - initialize a zstd compression context * @workspace: The workspace to emplace the context into. It must outlive @@ -424,6 +448,16 @@ typedef ZSTD_CStream zstd_cstream; */ size_t zstd_cstream_workspace_bound(const zstd_compression_parameters *cparams); +/** + * zstd_cstream_workspace_bound_with_ext_seq_prod() - memory needed to initialize + * a zstd_cstream when using the block-level external sequence producer API. + * @cparams: The compression parameters to be used for compression. + * + * Return: A lower bound on the size of the workspace that is passed to + * zstd_init_cstream(). + */ +size_t zstd_cstream_workspace_bound_with_ext_seq_prod(const zstd_compression_parameters *cparams); + /** * zstd_init_cstream() - initialize a zstd streaming compression context * @parameters The zstd parameters to use for compression. @@ -583,6 +617,18 @@ size_t zstd_decompress_stream(zstd_dstream *dstream, zstd_out_buffer *output, */ size_t zstd_find_frame_compressed_size(const void *src, size_t src_size); +/** + * zstd_register_sequence_producer() - exposes the zstd library function + * ZSTD_registerSequenceProducer(). This is used for the block-level external + * sequence producer API. See upstream zstd.h for detailed documentation. + */ +typedef ZSTD_sequenceProducer_F zstd_sequence_producer_f; +void zstd_register_sequence_producer( + zstd_cctx *cctx, + void* sequence_producer_state, + zstd_sequence_producer_f sequence_producer +); + /** * struct zstd_frame_params - zstd frame parameters stored in the frame header * @frameContentSize: The frame content size, or ZSTD_CONTENTSIZE_UNKNOWN if not @@ -596,7 +642,7 @@ size_t zstd_find_frame_compressed_size(const void *src, size_t src_size); * * See zstd_lib.h. */ -typedef ZSTD_frameHeader zstd_frame_header; +typedef ZSTD_FrameHeader zstd_frame_header; /** * zstd_get_frame_header() - extracts parameters from a zstd or skippable frame @@ -611,4 +657,35 @@ typedef ZSTD_frameHeader zstd_frame_header; size_t zstd_get_frame_header(zstd_frame_header *params, const void *src, size_t src_size); +/** + * struct zstd_sequence - a sequence of literals or a match + * + * @offset: The offset of the match + * @litLength: The literal length of the sequence + * @matchLength: The match length of the sequence + * @rep: Represents which repeat offset is used + */ +typedef ZSTD_Sequence zstd_sequence; + +/** + * zstd_compress_sequences_and_literals() - compress an array of zstd_sequence and literals + * + * @cctx: The zstd compression context. + * @dst: The buffer to compress the data into. + * @dst_capacity: The size of the destination buffer. + * @in_seqs: The array of zstd_sequence to compress. + * @in_seqs_size: The number of sequences in in_seqs. + * @literals: The literals associated to the sequences to be compressed. + * @lit_size: The size of the literals in the literals buffer. + * @lit_capacity: The size of the literals buffer. + * @decompressed_size: The size of the input data + * + * Return: The compressed size or an error, which can be checked using + * zstd_is_error(). + */ +size_t zstd_compress_sequences_and_literals(zstd_cctx *cctx, void* dst, size_t dst_capacity, + const zstd_sequence *in_seqs, size_t in_seqs_size, + const void* literals, size_t lit_size, size_t lit_capacity, + size_t decompressed_size); + #endif /* LINUX_ZSTD_H */ diff --git a/include/linux/zstd_errors.h b/include/linux/zstd_errors.h index 58b6dd45a969..c307fb011132 100644 --- a/include/linux/zstd_errors.h +++ b/include/linux/zstd_errors.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -12,13 +13,18 @@ #define ZSTD_ERRORS_H_398273423 -/*===== dependency =====*/ -#include /* size_t */ +/* ===== ZSTDERRORLIB_API : control library symbols visibility ===== */ +#define ZSTDERRORLIB_VISIBLE +#ifndef ZSTDERRORLIB_HIDDEN +# if (__GNUC__ >= 4) && !defined(__MINGW32__) +# define ZSTDERRORLIB_HIDDEN __attribute__ ((visibility ("hidden"))) +# else +# define ZSTDERRORLIB_HIDDEN +# endif +#endif -/* ===== ZSTDERRORLIB_API : control library symbols visibility ===== */ -#define ZSTDERRORLIB_VISIBILITY -#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBILITY +#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBLE /*-********************************************* * Error codes list @@ -43,14 +49,18 @@ typedef enum { ZSTD_error_frameParameter_windowTooLarge = 16, ZSTD_error_corruption_detected = 20, ZSTD_error_checksum_wrong = 22, + ZSTD_error_literals_headerWrong = 24, ZSTD_error_dictionary_corrupted = 30, ZSTD_error_dictionary_wrong = 32, ZSTD_error_dictionaryCreation_failed = 34, ZSTD_error_parameter_unsupported = 40, + ZSTD_error_parameter_combination_unsupported = 41, ZSTD_error_parameter_outOfBound = 42, ZSTD_error_tableLog_tooLarge = 44, ZSTD_error_maxSymbolValue_tooLarge = 46, ZSTD_error_maxSymbolValue_tooSmall = 48, + ZSTD_error_cannotProduce_uncompressedBlock = 49, + ZSTD_error_stabilityCondition_notRespected = 50, ZSTD_error_stage_wrong = 60, ZSTD_error_init_missing = 62, ZSTD_error_memory_allocation = 64, @@ -58,18 +68,18 @@ typedef enum { ZSTD_error_dstSize_tooSmall = 70, ZSTD_error_srcSize_wrong = 72, ZSTD_error_dstBuffer_null = 74, + ZSTD_error_noForwardProgress_destFull = 80, + ZSTD_error_noForwardProgress_inputEmpty = 82, /* following error codes are __NOT STABLE__, they can be removed or changed in future versions */ ZSTD_error_frameIndex_tooLarge = 100, ZSTD_error_seekableIO = 102, ZSTD_error_dstBuffer_wrong = 104, ZSTD_error_srcBuffer_wrong = 105, + ZSTD_error_sequenceProducer_failed = 106, + ZSTD_error_externalSequences_invalid = 107, ZSTD_error_maxCode = 120 /* never EVER use this value directly, it can change in future versions! Use ZSTD_isError() instead */ } ZSTD_ErrorCode; -/*! ZSTD_getErrorCode() : - convert a `size_t` function result into a `ZSTD_ErrorCode` enum type, - which can be used to compare with enum list published above */ -ZSTDERRORLIB_API ZSTD_ErrorCode ZSTD_getErrorCode(size_t functionResult); ZSTDERRORLIB_API const char* ZSTD_getErrorString(ZSTD_ErrorCode code); /*< Same as ZSTD_getErrorName, but using a `ZSTD_ErrorCode` enum argument */ diff --git a/include/linux/zstd_lib.h b/include/linux/zstd_lib.h index 79d55465d5c1..e295d4125dde 100644 --- a/include/linux/zstd_lib.h +++ b/include/linux/zstd_lib.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,23 +12,47 @@ #ifndef ZSTD_H_235446 #define ZSTD_H_235446 -/* ====== Dependency ======*/ -#include /* INT_MAX */ + +/* ====== Dependencies ======*/ #include /* size_t */ +#include /* list of errors */ +#if !defined(ZSTD_H_ZSTD_STATIC_LINKING_ONLY) +#include /* INT_MAX */ +#endif /* ZSTD_STATIC_LINKING_ONLY */ + /* ===== ZSTDLIB_API : control library symbols visibility ===== */ -#ifndef ZSTDLIB_VISIBLE +#define ZSTDLIB_VISIBLE + +#ifndef ZSTDLIB_HIDDEN # if (__GNUC__ >= 4) && !defined(__MINGW32__) -# define ZSTDLIB_VISIBLE __attribute__ ((visibility ("default"))) # define ZSTDLIB_HIDDEN __attribute__ ((visibility ("hidden"))) # else -# define ZSTDLIB_VISIBLE # define ZSTDLIB_HIDDEN # endif #endif + #define ZSTDLIB_API ZSTDLIB_VISIBLE +/* Deprecation warnings : + * Should these warnings be a problem, it is generally possible to disable them, + * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. + * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. + */ +#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS +# define ZSTD_DEPRECATED(message) /* disable deprecation warnings */ +#else +# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) || defined(__IAR_SYSTEMS_ICC__) +# define ZSTD_DEPRECATED(message) __attribute__((deprecated(message))) +# elif (__GNUC__ >= 3) +# define ZSTD_DEPRECATED(message) __attribute__((deprecated)) +# else +# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") +# define ZSTD_DEPRECATED(message) +# endif +#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ + /* ***************************************************************************** Introduction @@ -65,7 +90,7 @@ /*------ Version ------*/ #define ZSTD_VERSION_MAJOR 1 #define ZSTD_VERSION_MINOR 5 -#define ZSTD_VERSION_RELEASE 2 +#define ZSTD_VERSION_RELEASE 7 #define ZSTD_VERSION_NUMBER (ZSTD_VERSION_MAJOR *100*100 + ZSTD_VERSION_MINOR *100 + ZSTD_VERSION_RELEASE) /*! ZSTD_versionNumber() : @@ -103,11 +128,12 @@ ZSTDLIB_API const char* ZSTD_versionString(void); /* ************************************* -* Simple API +* Simple Core API ***************************************/ /*! ZSTD_compress() : * Compresses `src` content as a single zstd compressed frame into already allocated `dst`. - * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. + * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have + * enough space to successfully compress the data. * @return : compressed size written into `dst` (<= `dstCapacity), * or an error code if it fails (which can be tested using ZSTD_isError()). */ ZSTDLIB_API size_t ZSTD_compress( void* dst, size_t dstCapacity, @@ -115,47 +141,55 @@ ZSTDLIB_API size_t ZSTD_compress( void* dst, size_t dstCapacity, int compressionLevel); /*! ZSTD_decompress() : - * `compressedSize` : must be the _exact_ size of some number of compressed and/or skippable frames. - * `dstCapacity` is an upper bound of originalSize to regenerate. - * If user cannot imply a maximum upper bound, it's better to use streaming mode to decompress data. - * @return : the number of bytes decompressed into `dst` (<= `dstCapacity`), - * or an errorCode if it fails (which can be tested using ZSTD_isError()). */ + * `compressedSize` : must be the _exact_ size of some number of compressed and/or skippable frames. + * Multiple compressed frames can be decompressed at once with this method. + * The result will be the concatenation of all decompressed frames, back to back. + * `dstCapacity` is an upper bound of originalSize to regenerate. + * First frame's decompressed size can be extracted using ZSTD_getFrameContentSize(). + * If maximum upper bound isn't known, prefer using streaming mode to decompress data. + * @return : the number of bytes decompressed into `dst` (<= `dstCapacity`), + * or an errorCode if it fails (which can be tested using ZSTD_isError()). */ ZSTDLIB_API size_t ZSTD_decompress( void* dst, size_t dstCapacity, const void* src, size_t compressedSize); + +/*====== Decompression helper functions ======*/ + /*! ZSTD_getFrameContentSize() : requires v1.3.0+ - * `src` should point to the start of a ZSTD encoded frame. - * `srcSize` must be at least as large as the frame header. - * hint : any size >= `ZSTD_frameHeaderSize_max` is large enough. - * @return : - decompressed size of `src` frame content, if known - * - ZSTD_CONTENTSIZE_UNKNOWN if the size cannot be determined - * - ZSTD_CONTENTSIZE_ERROR if an error occurred (e.g. invalid magic number, srcSize too small) - * note 1 : a 0 return value means the frame is valid but "empty". - * note 2 : decompressed size is an optional field, it may not be present, typically in streaming mode. - * When `return==ZSTD_CONTENTSIZE_UNKNOWN`, data to decompress could be any size. - * In which case, it's necessary to use streaming mode to decompress data. - * Optionally, application can rely on some implicit limit, - * as ZSTD_decompress() only needs an upper bound of decompressed size. - * (For example, data could be necessarily cut into blocks <= 16 KB). - * note 3 : decompressed size is always present when compression is completed using single-pass functions, - * such as ZSTD_compress(), ZSTD_compressCCtx() ZSTD_compress_usingDict() or ZSTD_compress_usingCDict(). - * note 4 : decompressed size can be very large (64-bits value), - * potentially larger than what local system can handle as a single memory segment. - * In which case, it's necessary to use streaming mode to decompress data. - * note 5 : If source is untrusted, decompressed size could be wrong or intentionally modified. - * Always ensure return value fits within application's authorized limits. - * Each application can set its own limits. - * note 6 : This function replaces ZSTD_getDecompressedSize() */ + * `src` should point to the start of a ZSTD encoded frame. + * `srcSize` must be at least as large as the frame header. + * hint : any size >= `ZSTD_frameHeaderSize_max` is large enough. + * @return : - decompressed size of `src` frame content, if known + * - ZSTD_CONTENTSIZE_UNKNOWN if the size cannot be determined + * - ZSTD_CONTENTSIZE_ERROR if an error occurred (e.g. invalid magic number, srcSize too small) + * note 1 : a 0 return value means the frame is valid but "empty". + * When invoking this method on a skippable frame, it will return 0. + * note 2 : decompressed size is an optional field, it may not be present (typically in streaming mode). + * When `return==ZSTD_CONTENTSIZE_UNKNOWN`, data to decompress could be any size. + * In which case, it's necessary to use streaming mode to decompress data. + * Optionally, application can rely on some implicit limit, + * as ZSTD_decompress() only needs an upper bound of decompressed size. + * (For example, data could be necessarily cut into blocks <= 16 KB). + * note 3 : decompressed size is always present when compression is completed using single-pass functions, + * such as ZSTD_compress(), ZSTD_compressCCtx() ZSTD_compress_usingDict() or ZSTD_compress_usingCDict(). + * note 4 : decompressed size can be very large (64-bits value), + * potentially larger than what local system can handle as a single memory segment. + * In which case, it's necessary to use streaming mode to decompress data. + * note 5 : If source is untrusted, decompressed size could be wrong or intentionally modified. + * Always ensure return value fits within application's authorized limits. + * Each application can set its own limits. + * note 6 : This function replaces ZSTD_getDecompressedSize() */ #define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1) #define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) ZSTDLIB_API unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize); -/*! ZSTD_getDecompressedSize() : - * NOTE: This function is now obsolete, in favor of ZSTD_getFrameContentSize(). +/*! ZSTD_getDecompressedSize() (obsolete): + * This function is now obsolete, in favor of ZSTD_getFrameContentSize(). * Both functions work the same way, but ZSTD_getDecompressedSize() blends * "empty", "unknown" and "error" results to the same return value (0), * while ZSTD_getFrameContentSize() gives them separate return values. * @return : decompressed size of `src` frame content _if known and not empty_, 0 otherwise. */ +ZSTD_DEPRECATED("Replaced by ZSTD_getFrameContentSize") ZSTDLIB_API unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize); /*! ZSTD_findFrameCompressedSize() : Requires v1.4.0+ @@ -163,18 +197,50 @@ ZSTDLIB_API unsigned long long ZSTD_getDecompressedSize(const void* src, size_t * `srcSize` must be >= first frame size * @return : the compressed size of the first frame starting at `src`, * suitable to pass as `srcSize` to `ZSTD_decompress` or similar, - * or an error code if input is invalid */ + * or an error code if input is invalid + * Note 1: this method is called _find*() because it's not enough to read the header, + * it may have to scan through the frame's content, to reach its end. + * Note 2: this method also works with Skippable Frames. In which case, + * it returns the size of the complete skippable frame, + * which is always equal to its content size + 8 bytes for headers. */ ZSTDLIB_API size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize); -/*====== Helper functions ======*/ -#define ZSTD_COMPRESSBOUND(srcSize) ((srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ -ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ -ZSTDLIB_API unsigned ZSTD_isError(size_t code); /*!< tells if a `size_t` function result is an error code */ -ZSTDLIB_API const char* ZSTD_getErrorName(size_t code); /*!< provides readable string from an error code */ -ZSTDLIB_API int ZSTD_minCLevel(void); /*!< minimum negative compression level allowed, requires v1.4.0+ */ -ZSTDLIB_API int ZSTD_maxCLevel(void); /*!< maximum compression level available */ -ZSTDLIB_API int ZSTD_defaultCLevel(void); /*!< default compression level, specified by ZSTD_CLEVEL_DEFAULT, requires v1.5.0+ */ +/*====== Compression helper functions ======*/ + +/*! ZSTD_compressBound() : + * maximum compressed size in worst case single-pass scenario. + * When invoking `ZSTD_compress()`, or any other one-pass compression function, + * it's recommended to provide @dstCapacity >= ZSTD_compressBound(srcSize) + * as it eliminates one potential failure scenario, + * aka not enough room in dst buffer to write the compressed frame. + * Note : ZSTD_compressBound() itself can fail, if @srcSize >= ZSTD_MAX_INPUT_SIZE . + * In which case, ZSTD_compressBound() will return an error code + * which can be tested using ZSTD_isError(). + * + * ZSTD_COMPRESSBOUND() : + * same as ZSTD_compressBound(), but as a macro. + * It can be used to produce constants, which can be useful for static allocation, + * for example to size a static array on stack. + * Will produce constant value 0 if srcSize is too large. + */ +#define ZSTD_MAX_INPUT_SIZE ((sizeof(size_t)==8) ? 0xFF00FF00FF00FF00ULL : 0xFF00FF00U) +#define ZSTD_COMPRESSBOUND(srcSize) (((size_t)(srcSize) >= ZSTD_MAX_INPUT_SIZE) ? 0 : (srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ +ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ + + +/*====== Error helper functions ======*/ +/* ZSTD_isError() : + * Most ZSTD_* functions returning a size_t value can be tested for error, + * using ZSTD_isError(). + * @return 1 if error, 0 otherwise + */ +ZSTDLIB_API unsigned ZSTD_isError(size_t result); /*!< tells if a `size_t` function result is an error code */ +ZSTDLIB_API ZSTD_ErrorCode ZSTD_getErrorCode(size_t functionResult); /* convert a result into an error code, which can be compared to error enum list */ +ZSTDLIB_API const char* ZSTD_getErrorName(size_t result); /*!< provides readable string from a function result */ +ZSTDLIB_API int ZSTD_minCLevel(void); /*!< minimum negative compression level allowed, requires v1.4.0+ */ +ZSTDLIB_API int ZSTD_maxCLevel(void); /*!< maximum compression level available */ +ZSTDLIB_API int ZSTD_defaultCLevel(void); /*!< default compression level, specified by ZSTD_CLEVEL_DEFAULT, requires v1.5.0+ */ /* ************************************* @@ -182,25 +248,25 @@ ZSTDLIB_API int ZSTD_defaultCLevel(void); /*!< default compres ***************************************/ /*= Compression context * When compressing many times, - * it is recommended to allocate a context just once, - * and re-use it for each successive compression operation. - * This will make workload friendlier for system's memory. + * it is recommended to allocate a compression context just once, + * and reuse it for each successive compression operation. + * This will make the workload easier for system's memory. * Note : re-using context is just a speed / resource optimization. * It doesn't change the compression ratio, which remains identical. - * Note 2 : In multi-threaded environments, - * use one different context per thread for parallel execution. + * Note 2: For parallel execution in multi-threaded environments, + * use one different context per thread . */ typedef struct ZSTD_CCtx_s ZSTD_CCtx; ZSTDLIB_API ZSTD_CCtx* ZSTD_createCCtx(void); -ZSTDLIB_API size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx); /* accept NULL pointer */ +ZSTDLIB_API size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx); /* compatible with NULL pointer */ /*! ZSTD_compressCCtx() : * Same as ZSTD_compress(), using an explicit ZSTD_CCtx. - * Important : in order to behave similarly to `ZSTD_compress()`, - * this function compresses at requested compression level, - * __ignoring any other parameter__ . + * Important : in order to mirror `ZSTD_compress()` behavior, + * this function compresses at the requested compression level, + * __ignoring any other advanced parameter__ . * If any advanced parameter was set using the advanced API, - * they will all be reset. Only `compressionLevel` remains. + * they will all be reset. Only @compressionLevel remains. */ ZSTDLIB_API size_t ZSTD_compressCCtx(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, @@ -210,7 +276,7 @@ ZSTDLIB_API size_t ZSTD_compressCCtx(ZSTD_CCtx* cctx, /*= Decompression context * When decompressing many times, * it is recommended to allocate a context only once, - * and re-use it for each successive compression operation. + * and reuse it for each successive compression operation. * This will make workload friendlier for system's memory. * Use one context per thread for parallel execution. */ typedef struct ZSTD_DCtx_s ZSTD_DCtx; @@ -220,7 +286,7 @@ ZSTDLIB_API size_t ZSTD_freeDCtx(ZSTD_DCtx* dctx); /* accept NULL pointer * /*! ZSTD_decompressDCtx() : * Same as ZSTD_decompress(), * requires an allocated ZSTD_DCtx. - * Compatible with sticky parameters. + * Compatible with sticky parameters (see below). */ ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, @@ -236,12 +302,12 @@ ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, * using ZSTD_CCtx_set*() functions. * Pushed parameters are sticky : they are valid for next compressed frame, and any subsequent frame. * "sticky" parameters are applicable to `ZSTD_compress2()` and `ZSTD_compressStream*()` ! - * __They do not apply to "simple" one-shot variants such as ZSTD_compressCCtx()__ . + * __They do not apply to one-shot variants such as ZSTD_compressCCtx()__ . * * It's possible to reset all parameters to "default" using ZSTD_CCtx_reset(). * * This API supersedes all other "advanced" API entry points in the experimental section. - * In the future, we expect to remove from experimental API entry points which are redundant with this API. + * In the future, we expect to remove API entry points from experimental which are redundant with this API. */ @@ -324,6 +390,19 @@ typedef enum { * The higher the value of selected strategy, the more complex it is, * resulting in stronger and slower compression. * Special: value 0 means "use default strategy". */ + + ZSTD_c_targetCBlockSize=130, /* v1.5.6+ + * Attempts to fit compressed block size into approximately targetCBlockSize. + * Bound by ZSTD_TARGETCBLOCKSIZE_MIN and ZSTD_TARGETCBLOCKSIZE_MAX. + * Note that it's not a guarantee, just a convergence target (default:0). + * No target when targetCBlockSize == 0. + * This is helpful in low bandwidth streaming environments to improve end-to-end latency, + * when a client can make use of partial documents (a prominent example being Chrome). + * Note: this parameter is stable since v1.5.6. + * It was present as an experimental parameter in earlier versions, + * but it's not recommended using it with earlier library versions + * due to massive performance regressions. + */ /* LDM mode parameters */ ZSTD_c_enableLongDistanceMatching=160, /* Enable long distance matching. * This parameter is designed to improve compression ratio @@ -403,15 +482,18 @@ typedef enum { * ZSTD_c_forceMaxWindow * ZSTD_c_forceAttachDict * ZSTD_c_literalCompressionMode - * ZSTD_c_targetCBlockSize * ZSTD_c_srcSizeHint * ZSTD_c_enableDedicatedDictSearch * ZSTD_c_stableInBuffer * ZSTD_c_stableOutBuffer * ZSTD_c_blockDelimiters * ZSTD_c_validateSequences - * ZSTD_c_useBlockSplitter + * ZSTD_c_blockSplitterLevel + * ZSTD_c_splitAfterSequences * ZSTD_c_useRowMatchFinder + * ZSTD_c_prefetchCDictTables + * ZSTD_c_enableSeqProducerFallback + * ZSTD_c_maxBlockSize * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. * note : never ever use experimentalParam? names directly; * also, the enums values themselves are unstable and can still change. @@ -421,7 +503,7 @@ typedef enum { ZSTD_c_experimentalParam3=1000, ZSTD_c_experimentalParam4=1001, ZSTD_c_experimentalParam5=1002, - ZSTD_c_experimentalParam6=1003, + /* was ZSTD_c_experimentalParam6=1003; is now ZSTD_c_targetCBlockSize */ ZSTD_c_experimentalParam7=1004, ZSTD_c_experimentalParam8=1005, ZSTD_c_experimentalParam9=1006, @@ -430,7 +512,12 @@ typedef enum { ZSTD_c_experimentalParam12=1009, ZSTD_c_experimentalParam13=1010, ZSTD_c_experimentalParam14=1011, - ZSTD_c_experimentalParam15=1012 + ZSTD_c_experimentalParam15=1012, + ZSTD_c_experimentalParam16=1013, + ZSTD_c_experimentalParam17=1014, + ZSTD_c_experimentalParam18=1015, + ZSTD_c_experimentalParam19=1016, + ZSTD_c_experimentalParam20=1017 } ZSTD_cParameter; typedef struct { @@ -493,7 +580,7 @@ typedef enum { * They will be used to compress next frame. * Resetting session never fails. * - The parameters : changes all parameters back to "default". - * This removes any reference to any dictionary too. + * This also removes any reference to any dictionary or external sequence producer. * Parameters can only be changed between 2 sessions (i.e. no compression is currently ongoing) * otherwise the reset fails, and function returns an error value (which can be tested using ZSTD_isError()) * - Both : similar to resetting the session, followed by resetting parameters. @@ -502,11 +589,13 @@ ZSTDLIB_API size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset); /*! ZSTD_compress2() : * Behave the same as ZSTD_compressCCtx(), but compression parameters are set using the advanced API. + * (note that this entry point doesn't even expose a compression level parameter). * ZSTD_compress2() always starts a new frame. * Should cctx hold data from a previously unfinished frame, everything about it is forgotten. * - Compression parameters are pushed into CCtx before starting compression, using ZSTD_CCtx_set*() * - The function is always blocking, returns when compression is completed. - * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. + * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have + * enough space to successfully compress the data, though it is possible it fails for other reasons. * @return : compressed size written into `dst` (<= `dstCapacity), * or an error code if it fails (which can be tested using ZSTD_isError()). */ @@ -543,13 +632,17 @@ typedef enum { * ZSTD_d_stableOutBuffer * ZSTD_d_forceIgnoreChecksum * ZSTD_d_refMultipleDDicts + * ZSTD_d_disableHuffmanAssembly + * ZSTD_d_maxBlockSize * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. * note : never ever use experimentalParam? names directly */ ZSTD_d_experimentalParam1=1000, ZSTD_d_experimentalParam2=1001, ZSTD_d_experimentalParam3=1002, - ZSTD_d_experimentalParam4=1003 + ZSTD_d_experimentalParam4=1003, + ZSTD_d_experimentalParam5=1004, + ZSTD_d_experimentalParam6=1005 } ZSTD_dParameter; @@ -604,14 +697,14 @@ typedef struct ZSTD_outBuffer_s { * A ZSTD_CStream object is required to track streaming operation. * Use ZSTD_createCStream() and ZSTD_freeCStream() to create/release resources. * ZSTD_CStream objects can be reused multiple times on consecutive compression operations. -* It is recommended to re-use ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. +* It is recommended to reuse ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. * * For parallel execution, use one separate ZSTD_CStream per thread. * * note : since v1.3.0, ZSTD_CStream and ZSTD_CCtx are the same thing. * * Parameters are sticky : when starting a new compression on the same context, -* it will re-use the same sticky parameters as previous compression session. +* it will reuse the same sticky parameters as previous compression session. * When in doubt, it's recommended to fully initialize the context before usage. * Use ZSTD_CCtx_reset() to reset the context and ZSTD_CCtx_setParameter(), * ZSTD_CCtx_setPledgedSrcSize(), or ZSTD_CCtx_loadDictionary() and friends to @@ -700,6 +793,11 @@ typedef enum { * only ZSTD_e_end or ZSTD_e_flush operations are allowed. * Before starting a new compression job, or changing compression parameters, * it is required to fully flush internal buffers. + * - note: if an operation ends with an error, it may leave @cctx in an undefined state. + * Therefore, it's UB to invoke ZSTD_compressStream2() of ZSTD_compressStream() on such a state. + * In order to be re-employed after an error, a state must be reset, + * which can be done explicitly (ZSTD_CCtx_reset()), + * or is sometimes implied by methods starting a new compression job (ZSTD_initCStream(), ZSTD_compressCCtx()) */ ZSTDLIB_API size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, ZSTD_outBuffer* output, @@ -728,8 +826,6 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output * This following is a legacy streaming API, available since v1.0+ . * It can be replaced by ZSTD_CCtx_reset() and ZSTD_compressStream2(). * It is redundant, but remains fully supported. - * Streaming in combination with advanced parameters and dictionary compression - * can only be used through the new API. ******************************************************************************/ /*! @@ -738,6 +834,9 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); * ZSTD_CCtx_refCDict(zcs, NULL); // clear the dictionary (if any) * ZSTD_CCtx_setParameter(zcs, ZSTD_c_compressionLevel, compressionLevel); + * + * Note that ZSTD_initCStream() clears any previously set dictionary. Use the new API + * to compress with a dictionary. */ ZSTDLIB_API size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel); /*! @@ -758,7 +857,7 @@ ZSTDLIB_API size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output); * * A ZSTD_DStream object is required to track streaming operations. * Use ZSTD_createDStream() and ZSTD_freeDStream() to create/release resources. -* ZSTD_DStream objects can be re-used multiple times. +* ZSTD_DStream objects can be re-employed multiple times. * * Use ZSTD_initDStream() to start a new decompression operation. * @return : recommended first input size @@ -768,16 +867,21 @@ ZSTDLIB_API size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output); * The function will update both `pos` fields. * If `input.pos < input.size`, some input has not been consumed. * It's up to the caller to present again remaining data. +* * The function tries to flush all data decoded immediately, respecting output buffer size. * If `output.pos < output.size`, decoder has flushed everything it could. -* But if `output.pos == output.size`, there might be some data left within internal buffers., +* +* However, when `output.pos == output.size`, it's more difficult to know. +* If @return > 0, the frame is not complete, meaning +* either there is still some data left to flush within internal buffers, +* or there is more input to read to complete the frame (or both). * In which case, call ZSTD_decompressStream() again to flush whatever remains in the buffer. * Note : with no additional input provided, amount of data flushed is necessarily <= ZSTD_BLOCKSIZE_MAX. * @return : 0 when a frame is completely decoded and fully flushed, * or an error code, which can be tested using ZSTD_isError(), * or any other value > 0, which means there is still some decoding or flushing to do to complete current frame : * the return value is a suggested next input size (just a hint for better latency) -* that will never request more than the remaining frame size. +* that will never request more than the remaining content of the compressed frame. * *******************************************************************************/ typedef ZSTD_DCtx ZSTD_DStream; /*< DCtx and DStream are now effectively same object (>= v1.3.0) */ @@ -788,13 +892,38 @@ ZSTDLIB_API size_t ZSTD_freeDStream(ZSTD_DStream* zds); /* accept NULL pointer /*===== Streaming decompression functions =====*/ -/* This function is redundant with the advanced API and equivalent to: +/*! ZSTD_initDStream() : + * Initialize/reset DStream state for new decompression operation. + * Call before new decompression operation using same DStream. * + * Note : This function is redundant with the advanced API and equivalent to: * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); * ZSTD_DCtx_refDDict(zds, NULL); */ ZSTDLIB_API size_t ZSTD_initDStream(ZSTD_DStream* zds); +/*! ZSTD_decompressStream() : + * Streaming decompression function. + * Call repetitively to consume full input updating it as necessary. + * Function will update both input and output `pos` fields exposing current state via these fields: + * - `input.pos < input.size`, some input remaining and caller should provide remaining input + * on the next call. + * - `output.pos < output.size`, decoder flushed internal output buffer. + * - `output.pos == output.size`, unflushed data potentially present in the internal buffers, + * check ZSTD_decompressStream() @return value, + * if > 0, invoke it again to flush remaining data to output. + * Note : with no additional input, amount of data flushed <= ZSTD_BLOCKSIZE_MAX. + * + * @return : 0 when a frame is completely decoded and fully flushed, + * or an error code, which can be tested using ZSTD_isError(), + * or any other value > 0, which means there is some decoding or flushing to do to complete current frame. + * + * Note: when an operation returns with an error code, the @zds state may be left in undefined state. + * It's UB to invoke `ZSTD_decompressStream()` on such a state. + * In order to re-use such a state, it must be first reset, + * which can be done explicitly (`ZSTD_DCtx_reset()`), + * or is implied for operations starting some new decompression job (`ZSTD_initDStream`, `ZSTD_decompressDCtx()`, `ZSTD_decompress_usingDict()`) + */ ZSTDLIB_API size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inBuffer* input); ZSTDLIB_API size_t ZSTD_DStreamInSize(void); /*!< recommended size for input buffer */ @@ -913,7 +1042,7 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict); * If @return == 0, the dictID could not be decoded. * This could for one of the following reasons : * - The frame does not require a dictionary to be decoded (most common case). - * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden information. + * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden piece of information. * Note : this use case also happens when using a non-conformant dictionary. * - `srcSize` is too small, and as a result, the frame header could not be decoded (only possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`). * - This is not a Zstandard frame. @@ -925,9 +1054,11 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * Advanced dictionary and prefix API (Requires v1.4.0+) * * This API allows dictionaries to be used with ZSTD_compress2(), - * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). Dictionaries are sticky, and - * only reset with the context is reset with ZSTD_reset_parameters or - * ZSTD_reset_session_and_parameters. Prefixes are single-use. + * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). + * Dictionaries are sticky, they remain valid when same context is reused, + * they only reset when the context is reset + * with ZSTD_reset_parameters or ZSTD_reset_session_and_parameters. + * In contrast, Prefixes are single-use. ******************************************************************************/ @@ -937,8 +1068,9 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special: Loading a NULL (or 0-size) dictionary invalidates previous dictionary, * meaning "return to no-dictionary mode". - * Note 1 : Dictionary is sticky, it will be used for all future compressed frames. - * To return to "no-dictionary" situation, load a NULL dictionary (or reset parameters). + * Note 1 : Dictionary is sticky, it will be used for all future compressed frames, + * until parameters are reset, a new dictionary is loaded, or the dictionary + * is explicitly invalidated by loading a NULL dictionary. * Note 2 : Loading a dictionary involves building tables. * It's also a CPU consuming operation, with non-negligible impact on latency. * Tables are dependent on compression parameters, and for this reason, @@ -947,11 +1079,15 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * Use experimental ZSTD_CCtx_loadDictionary_byReference() to reference content instead. * In such a case, dictionary buffer must outlive its users. * Note 4 : Use ZSTD_CCtx_loadDictionary_advanced() - * to precisely select how dictionary content must be interpreted. */ + * to precisely select how dictionary content must be interpreted. + * Note 5 : This method does not benefit from LDM (long distance mode). + * If you want to employ LDM on some large dictionary content, + * prefer employing ZSTD_CCtx_refPrefix() described below. + */ ZSTDLIB_API size_t ZSTD_CCtx_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize); /*! ZSTD_CCtx_refCDict() : Requires v1.4.0+ - * Reference a prepared dictionary, to be used for all next compressed frames. + * Reference a prepared dictionary, to be used for all future compressed frames. * Note that compression parameters are enforced from within CDict, * and supersede any compression parameter previously set within CCtx. * The parameters ignored are labelled as "superseded-by-cdict" in the ZSTD_cParameter enum docs. @@ -970,6 +1106,7 @@ ZSTDLIB_API size_t ZSTD_CCtx_refCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); * Decompression will need same prefix to properly regenerate data. * Compressing with a prefix is similar in outcome as performing a diff and compressing it, * but performs much faster, especially during decompression (compression speed is tunable with compression level). + * This method is compatible with LDM (long distance mode). * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special: Adding any prefix (including NULL) invalidates any previous prefix or dictionary * Note 1 : Prefix buffer is referenced. It **must** outlive compression. @@ -986,9 +1123,9 @@ ZSTDLIB_API size_t ZSTD_CCtx_refPrefix(ZSTD_CCtx* cctx, const void* prefix, size_t prefixSize); /*! ZSTD_DCtx_loadDictionary() : Requires v1.4.0+ - * Create an internal DDict from dict buffer, - * to be used to decompress next frames. - * The dictionary remains valid for all future frames, until explicitly invalidated. + * Create an internal DDict from dict buffer, to be used to decompress all future frames. + * The dictionary remains valid for all future frames, until explicitly invalidated, or + * a new dictionary is loaded. * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special : Adding a NULL (or 0-size) dictionary invalidates any previous dictionary, * meaning "return to no-dictionary mode". @@ -1012,9 +1149,10 @@ ZSTDLIB_API size_t ZSTD_DCtx_loadDictionary(ZSTD_DCtx* dctx, const void* dict, s * The memory for the table is allocated on the first call to refDDict, and can be * freed with ZSTD_freeDCtx(). * + * If called with ZSTD_d_refMultipleDDicts disabled (the default), only one dictionary + * will be managed, and referencing a dictionary effectively "discards" any previous one. + * * @result : 0, or an error code (which can be tested with ZSTD_isError()). - * Note 1 : Currently, only one dictionary can be managed. - * Referencing a new dictionary effectively "discards" any previous one. * Special: referencing a NULL DDict means "return to no-dictionary mode". * Note 2 : DDict is just referenced, its lifetime must outlive its usage from DCtx. */ @@ -1051,6 +1189,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DStream(const ZSTD_DStream* zds); ZSTDLIB_API size_t ZSTD_sizeof_CDict(const ZSTD_CDict* cdict); ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); + #endif /* ZSTD_H_235446 */ @@ -1066,29 +1205,12 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #if !defined(ZSTD_H_ZSTD_STATIC_LINKING_ONLY) #define ZSTD_H_ZSTD_STATIC_LINKING_ONLY + /* This can be overridden externally to hide static symbols. */ #ifndef ZSTDLIB_STATIC_API #define ZSTDLIB_STATIC_API ZSTDLIB_VISIBLE #endif -/* Deprecation warnings : - * Should these warnings be a problem, it is generally possible to disable them, - * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. - * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. - */ -#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API /* disable deprecation warnings */ -#else -# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated(message))) -# elif (__GNUC__ >= 3) -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated)) -# else -# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API -# endif -#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ - /* ************************************************************************************** * experimental API (static linking only) **************************************************************************************** @@ -1123,6 +1245,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #define ZSTD_TARGETLENGTH_MIN 0 /* note : comparing this constant to an unsigned results in a tautological test */ #define ZSTD_STRATEGY_MIN ZSTD_fast #define ZSTD_STRATEGY_MAX ZSTD_btultra2 +#define ZSTD_BLOCKSIZE_MAX_MIN (1 << 10) /* The minimum valid max blocksize. Maximum blocksizes smaller than this make compressBound() inaccurate. */ #define ZSTD_OVERLAPLOG_MIN 0 @@ -1146,7 +1269,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #define ZSTD_LDM_HASHRATELOG_MAX (ZSTD_WINDOWLOG_MAX - ZSTD_HASHLOG_MIN) /* Advanced parameter bounds */ -#define ZSTD_TARGETCBLOCKSIZE_MIN 64 +#define ZSTD_TARGETCBLOCKSIZE_MIN 1340 /* suitable to fit into an ethernet / wifi / 4G transport frame */ #define ZSTD_TARGETCBLOCKSIZE_MAX ZSTD_BLOCKSIZE_MAX #define ZSTD_SRCSIZEHINT_MIN 0 #define ZSTD_SRCSIZEHINT_MAX INT_MAX @@ -1188,7 +1311,7 @@ typedef struct { * * Note: This field is optional. ZSTD_generateSequences() will calculate the value of * 'rep', but repeat offsets do not necessarily need to be calculated from an external - * sequence provider's perspective. For example, ZSTD_compressSequences() does not + * sequence provider perspective. For example, ZSTD_compressSequences() does not * use this 'rep' field at all (as of now). */ } ZSTD_Sequence; @@ -1293,17 +1416,18 @@ typedef enum { } ZSTD_literalCompressionMode_e; typedef enum { - /* Note: This enum controls features which are conditionally beneficial. Zstd typically will make a final - * decision on whether or not to enable the feature (ZSTD_ps_auto), but setting the switch to ZSTD_ps_enable - * or ZSTD_ps_disable allow for a force enable/disable the feature. + /* Note: This enum controls features which are conditionally beneficial. + * Zstd can take a decision on whether or not to enable the feature (ZSTD_ps_auto), + * but setting the switch to ZSTD_ps_enable or ZSTD_ps_disable force enable/disable the feature. */ ZSTD_ps_auto = 0, /* Let the library automatically determine whether the feature shall be enabled */ ZSTD_ps_enable = 1, /* Force-enable the feature */ ZSTD_ps_disable = 2 /* Do not use the feature */ -} ZSTD_paramSwitch_e; +} ZSTD_ParamSwitch_e; +#define ZSTD_paramSwitch_e ZSTD_ParamSwitch_e /* old name */ /* ************************************* -* Frame size functions +* Frame header and size functions ***************************************/ /*! ZSTD_findDecompressedSize() : @@ -1345,34 +1469,130 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_findDecompressedSize(const void* src, ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize); /*! ZSTD_frameHeaderSize() : - * srcSize must be >= ZSTD_FRAMEHEADERSIZE_PREFIX. + * srcSize must be large enough, aka >= ZSTD_FRAMEHEADERSIZE_PREFIX. * @return : size of the Frame Header, * or an error code (if srcSize is too small) */ ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize); +typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_FrameType_e; +#define ZSTD_frameType_e ZSTD_FrameType_e /* old name */ +typedef struct { + unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ + unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ + unsigned blockSizeMax; + ZSTD_FrameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ + unsigned headerSize; + unsigned dictID; /* for ZSTD_skippableFrame, contains the skippable magic variant [0-15] */ + unsigned checksumFlag; + unsigned _reserved1; + unsigned _reserved2; +} ZSTD_FrameHeader; +#define ZSTD_frameHeader ZSTD_FrameHeader /* old name */ + +/*! ZSTD_getFrameHeader() : + * decode Frame Header into `zfhPtr`, or requires larger `srcSize`. + * @return : 0 => header is complete, `zfhPtr` is correctly filled, + * >0 => `srcSize` is too small, @return value is the wanted `srcSize` amount, `zfhPtr` is not filled, + * or an error code, which can be tested using ZSTD_isError() */ +ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_FrameHeader* zfhPtr, const void* src, size_t srcSize); +/*! ZSTD_getFrameHeader_advanced() : + * same as ZSTD_getFrameHeader(), + * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ +ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_FrameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); + +/*! ZSTD_decompressionMargin() : + * Zstd supports in-place decompression, where the input and output buffers overlap. + * In this case, the output buffer must be at least (Margin + Output_Size) bytes large, + * and the input buffer must be at the end of the output buffer. + * + * _______________________ Output Buffer ________________________ + * | | + * | ____ Input Buffer ____| + * | | | + * v v v + * |---------------------------------------|-----------|----------| + * ^ ^ ^ + * |___________________ Output_Size ___________________|_ Margin _| + * + * NOTE: See also ZSTD_DECOMPRESSION_MARGIN(). + * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or + * ZSTD_decompressDCtx(). + * NOTE: This function supports multi-frame input. + * + * @param src The compressed frame(s) + * @param srcSize The size of the compressed frame(s) + * @returns The decompression margin or an error that can be checked with ZSTD_isError(). + */ +ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize); + +/*! ZSTD_DECOMPRESS_MARGIN() : + * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from + * the compressed frame, compute it from the original size and the blockSizeLog. + * See ZSTD_decompressionMargin() for details. + * + * WARNING: This macro does not support multi-frame input, the input must be a single + * zstd frame. If you need that support use the function, or implement it yourself. + * + * @param originalSize The original uncompressed size of the data. + * @param blockSize The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX). + * Unless you explicitly set the windowLog smaller than + * ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX. + */ +#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)( \ + ZSTD_FRAMEHEADERSIZE_MAX /* Frame header */ + \ + 4 /* checksum */ + \ + ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / blockSize)) /* 3 bytes per block */ + \ + (blockSize) /* One block of margin */ \ + )) + typedef enum { - ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */ - ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */ -} ZSTD_sequenceFormat_e; + ZSTD_sf_noBlockDelimiters = 0, /* ZSTD_Sequence[] has no block delimiters, just sequences */ + ZSTD_sf_explicitBlockDelimiters = 1 /* ZSTD_Sequence[] contains explicit block delimiters */ +} ZSTD_SequenceFormat_e; +#define ZSTD_sequenceFormat_e ZSTD_SequenceFormat_e /* old name */ + +/*! ZSTD_sequenceBound() : + * `srcSize` : size of the input buffer + * @return : upper-bound for the number of sequences that can be generated + * from a buffer of srcSize bytes + * + * note : returns number of sequences - to get bytes, multiply by sizeof(ZSTD_Sequence). + */ +ZSTDLIB_STATIC_API size_t ZSTD_sequenceBound(size_t srcSize); /*! ZSTD_generateSequences() : - * Generate sequences using ZSTD_compress2, given a source buffer. + * WARNING: This function is meant for debugging and informational purposes ONLY! + * Its implementation is flawed, and it will be deleted in a future version. + * It is not guaranteed to succeed, as there are several cases where it will give + * up and fail. You should NOT use this function in production code. + * + * This function is deprecated, and will be removed in a future version. + * + * Generate sequences using ZSTD_compress2(), given a source buffer. + * + * @param zc The compression context to be used for ZSTD_compress2(). Set any + * compression parameters you need on this context. + * @param outSeqs The output sequences buffer of size @p outSeqsSize + * @param outSeqsCapacity The size of the output sequences buffer. + * ZSTD_sequenceBound(srcSize) is an upper bound on the number + * of sequences that can be generated. + * @param src The source buffer to generate sequences from of size @p srcSize. + * @param srcSize The size of the source buffer. * * Each block will end with a dummy sequence * with offset == 0, matchLength == 0, and litLength == length of last literals. * litLength may be == 0, and if so, then the sequence of (of: 0 ml: 0 ll: 0) * simply acts as a block delimiter. * - * zc can be used to insert custom compression params. - * This function invokes ZSTD_compress2 - * - * The output of this function can be fed into ZSTD_compressSequences() with CCtx - * setting of ZSTD_c_blockDelimiters as ZSTD_sf_explicitBlockDelimiters - * @return : number of sequences generated + * @returns The number of sequences generated, necessarily less than + * ZSTD_sequenceBound(srcSize), or an error code that can be checked + * with ZSTD_isError(). */ - -ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, - size_t outSeqsSize, const void* src, size_t srcSize); +ZSTD_DEPRECATED("For debugging only, will be replaced by ZSTD_extractSequences()") +ZSTDLIB_STATIC_API size_t +ZSTD_generateSequences(ZSTD_CCtx* zc, + ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, + const void* src, size_t srcSize); /*! ZSTD_mergeBlockDelimiters() : * Given an array of ZSTD_Sequence, remove all sequences that represent block delimiters/last literals @@ -1388,8 +1608,10 @@ ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* o ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, size_t seqsSize); /*! ZSTD_compressSequences() : - * Compress an array of ZSTD_Sequence, generated from the original source buffer, into dst. - * If a dictionary is included, then the cctx should reference the dict. (see: ZSTD_CCtx_refCDict(), ZSTD_CCtx_loadDictionary(), etc.) + * Compress an array of ZSTD_Sequence, associated with @src buffer, into dst. + * @src contains the entire input (not just the literals). + * If @srcSize > sum(sequence.length), the remaining bytes are considered all literals + * If a dictionary is included, then the cctx should reference the dict (see: ZSTD_CCtx_refCDict(), ZSTD_CCtx_loadDictionary(), etc.). * The entire source is compressed into a single frame. * * The compression behavior changes based on cctx params. In particular: @@ -1398,11 +1620,17 @@ ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, si * the block size derived from the cctx, and sequences may be split. This is the default setting. * * If ZSTD_c_blockDelimiters == ZSTD_sf_explicitBlockDelimiters, the array of ZSTD_Sequence is expected to contain - * block delimiters (defined in ZSTD_Sequence). Behavior is undefined if no block delimiters are provided. + * valid block delimiters (defined in ZSTD_Sequence). Behavior is undefined if no block delimiters are provided. + * + * When ZSTD_c_blockDelimiters == ZSTD_sf_explicitBlockDelimiters, it's possible to decide generating repcodes + * using the advanced parameter ZSTD_c_repcodeResolution. Repcodes will improve compression ratio, though the benefit + * can vary greatly depending on Sequences. On the other hand, repcode resolution is an expensive operation. + * By default, it's disabled at low (<10) compression levels, and enabled above the threshold (>=10). + * ZSTD_c_repcodeResolution makes it possible to directly manage this processing in either direction. * - * If ZSTD_c_validateSequences == 0, this function will blindly accept the sequences provided. Invalid sequences cause undefined - * behavior. If ZSTD_c_validateSequences == 1, then if sequence is invalid (see doc/zstd_compression_format.md for - * specifics regarding offset/matchlength requirements) then the function will bail out and return an error. + * If ZSTD_c_validateSequences == 0, this function blindly accepts the Sequences provided. Invalid Sequences cause undefined + * behavior. If ZSTD_c_validateSequences == 1, then the function will detect invalid Sequences (see doc/zstd_compression_format.md for + * specifics regarding offset/matchlength requirements) and then bail out and return an error. * * In addition to the two adjustable experimental params, there are other important cctx params. * - ZSTD_c_minMatch MUST be set as less than or equal to the smallest match generated by the match finder. It has a minimum value of ZSTD_MINMATCH_MIN. @@ -1410,14 +1638,42 @@ ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, si * - ZSTD_c_windowLog affects offset validation: this function will return an error at higher debug levels if a provided offset * is larger than what the spec allows for a given window log and dictionary (if present). See: doc/zstd_compression_format.md * - * Note: Repcodes are, as of now, always re-calculated within this function, so ZSTD_Sequence::rep is unused. - * Note 2: Once we integrate ability to ingest repcodes, the explicit block delims mode must respect those repcodes exactly, - * and cannot emit an RLE block that disagrees with the repcode history - * @return : final compressed size or a ZSTD error. - */ -ZSTDLIB_STATIC_API size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstSize, - const ZSTD_Sequence* inSeqs, size_t inSeqsSize, - const void* src, size_t srcSize); + * Note: Repcodes are, as of now, always re-calculated within this function, ZSTD_Sequence.rep is effectively unused. + * Dev Note: Once ability to ingest repcodes become available, the explicit block delims mode must respect those repcodes exactly, + * and cannot emit an RLE block that disagrees with the repcode history. + * @return : final compressed size, or a ZSTD error code. + */ +ZSTDLIB_STATIC_API size_t +ZSTD_compressSequences(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, + const void* src, size_t srcSize); + + +/*! ZSTD_compressSequencesAndLiterals() : + * This is a variant of ZSTD_compressSequences() which, + * instead of receiving (src,srcSize) as input parameter, receives (literals,litSize), + * aka all the literals, already extracted and laid out into a single continuous buffer. + * This can be useful if the process generating the sequences also happens to generate the buffer of literals, + * thus skipping an extraction + caching stage. + * It's a speed optimization, useful when the right conditions are met, + * but it also features the following limitations: + * - Only supports explicit delimiter mode + * - Currently does not support Sequences validation (so input Sequences are trusted) + * - Not compatible with frame checksum, which must be disabled + * - If any block is incompressible, will fail and return an error + * - @litSize must be == sum of all @.litLength fields in @inSeqs. Any discrepancy will generate an error. + * - @litBufCapacity is the size of the underlying buffer into which literals are written, starting at address @literals. + * @litBufCapacity must be at least 8 bytes larger than @litSize. + * - @decompressedSize must be correct, and correspond to the sum of all Sequences. Any discrepancy will generate an error. + * @return : final compressed size, or a ZSTD error code. + */ +ZSTDLIB_STATIC_API size_t +ZSTD_compressSequencesAndLiterals(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const ZSTD_Sequence* inSeqs, size_t nbSequences, + const void* literals, size_t litSize, size_t litBufCapacity, + size_t decompressedSize); /*! ZSTD_writeSkippableFrame() : @@ -1425,8 +1681,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* ds * * Skippable frames begin with a 4-byte magic number. There are 16 possible choices of magic number, * ranging from ZSTD_MAGIC_SKIPPABLE_START to ZSTD_MAGIC_SKIPPABLE_START+15. - * As such, the parameter magicVariant controls the exact skippable frame magic number variant used, so - * the magic number used will be ZSTD_MAGIC_SKIPPABLE_START + magicVariant. + * As such, the parameter magicVariant controls the exact skippable frame magic number variant used, + * so the magic number used will be ZSTD_MAGIC_SKIPPABLE_START + magicVariant. * * Returns an error if destination buffer is not large enough, if the source size is not representable * with a 4-byte unsigned int, or if the parameter magicVariant is greater than 15 (and therefore invalid). @@ -1434,26 +1690,28 @@ ZSTDLIB_STATIC_API size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* ds * @return : number of bytes written or a ZSTD error. */ ZSTDLIB_STATIC_API size_t ZSTD_writeSkippableFrame(void* dst, size_t dstCapacity, - const void* src, size_t srcSize, unsigned magicVariant); + const void* src, size_t srcSize, + unsigned magicVariant); /*! ZSTD_readSkippableFrame() : - * Retrieves a zstd skippable frame containing data given by src, and writes it to dst buffer. + * Retrieves the content of a zstd skippable frame starting at @src, and writes it to @dst buffer. * - * The parameter magicVariant will receive the magicVariant that was supplied when the frame was written, - * i.e. magicNumber - ZSTD_MAGIC_SKIPPABLE_START. This can be NULL if the caller is not interested - * in the magicVariant. + * The parameter @magicVariant will receive the magicVariant that was supplied when the frame was written, + * i.e. magicNumber - ZSTD_MAGIC_SKIPPABLE_START. + * This can be NULL if the caller is not interested in the magicVariant. * * Returns an error if destination buffer is not large enough, or if the frame is not skippable. * * @return : number of bytes written or a ZSTD error. */ -ZSTDLIB_API size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, unsigned* magicVariant, - const void* src, size_t srcSize); +ZSTDLIB_STATIC_API size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, + unsigned* magicVariant, + const void* src, size_t srcSize); /*! ZSTD_isSkippableFrame() : * Tells if the content of `buffer` starts with a valid Frame Identifier for a skippable frame. */ -ZSTDLIB_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size); +ZSTDLIB_STATIC_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size); @@ -1464,48 +1722,59 @@ ZSTDLIB_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size); /*! ZSTD_estimate*() : * These functions make it possible to estimate memory usage * of a future {D,C}Ctx, before its creation. + * This is useful in combination with ZSTD_initStatic(), + * which makes it possible to employ a static buffer for ZSTD_CCtx* state. * * ZSTD_estimateCCtxSize() will provide a memory budget large enough - * for any compression level up to selected one. - * Note : Unlike ZSTD_estimateCStreamSize*(), this estimate - * does not include space for a window buffer. - * Therefore, the estimation is only guaranteed for single-shot compressions, not streaming. + * to compress data of any size using one-shot compression ZSTD_compressCCtx() or ZSTD_compress2() + * associated with any compression level up to max specified one. * The estimate will assume the input may be arbitrarily large, * which is the worst case. * + * Note that the size estimation is specific for one-shot compression, + * it is not valid for streaming (see ZSTD_estimateCStreamSize*()) + * nor other potential ways of using a ZSTD_CCtx* state. + * * When srcSize can be bound by a known and rather "small" value, - * this fact can be used to provide a tighter estimation - * because the CCtx compression context will need less memory. - * This tighter estimation can be provided by more advanced functions + * this knowledge can be used to provide a tighter budget estimation + * because the ZSTD_CCtx* state will need less memory for small inputs. + * This tighter estimation can be provided by employing more advanced functions * ZSTD_estimateCCtxSize_usingCParams(), which can be used in tandem with ZSTD_getCParams(), * and ZSTD_estimateCCtxSize_usingCCtxParams(), which can be used in tandem with ZSTD_CCtxParams_setParameter(). * Both can be used to estimate memory using custom compression parameters and arbitrary srcSize limits. * - * Note 2 : only single-threaded compression is supported. + * Note : only single-threaded compression is supported. * ZSTD_estimateCCtxSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. */ -ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int compressionLevel); +ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int maxCompressionLevel); ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams); ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params); ZSTDLIB_STATIC_API size_t ZSTD_estimateDCtxSize(void); /*! ZSTD_estimateCStreamSize() : - * ZSTD_estimateCStreamSize() will provide a budget large enough for any compression level up to selected one. - * It will also consider src size to be arbitrarily "large", which is worst case. + * ZSTD_estimateCStreamSize() will provide a memory budget large enough for streaming compression + * using any compression level up to the max specified one. + * It will also consider src size to be arbitrarily "large", which is a worst case scenario. * If srcSize is known to always be small, ZSTD_estimateCStreamSize_usingCParams() can provide a tighter estimation. * ZSTD_estimateCStreamSize_usingCParams() can be used in tandem with ZSTD_getCParams() to create cParams from compressionLevel. * ZSTD_estimateCStreamSize_usingCCtxParams() can be used in tandem with ZSTD_CCtxParams_setParameter(). Only single-threaded compression is supported. This function will return an error code if ZSTD_c_nbWorkers is >= 1. * Note : CStream size estimation is only correct for single-threaded compression. - * ZSTD_DStream memory budget depends on window Size. + * ZSTD_estimateCStreamSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. + * Note 2 : ZSTD_estimateCStreamSize* functions are not compatible with the Block-Level Sequence Producer API at this time. + * Size estimates assume that no external sequence producer is registered. + * + * ZSTD_DStream memory budget depends on frame's window Size. * This information can be passed manually, using ZSTD_estimateDStreamSize, * or deducted from a valid frame Header, using ZSTD_estimateDStreamSize_fromFrame(); + * Any frame requesting a window size larger than max specified one will be rejected. * Note : if streaming is init with function ZSTD_init?Stream_usingDict(), * an internal ?Dict will be created, which additional size is not estimated here. - * In this case, get total size by adding ZSTD_estimate?DictSize */ -ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int compressionLevel); + * In this case, get total size by adding ZSTD_estimate?DictSize + */ +ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int maxCompressionLevel); ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCParams(ZSTD_compressionParameters cParams); ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params); -ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t windowSize); +ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t maxWindowSize); ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize_fromFrame(const void* src, size_t srcSize); /*! ZSTD_estimate?DictSize() : @@ -1568,7 +1837,15 @@ typedef void (*ZSTD_freeFunction) (void* opaque, void* address); typedef struct { ZSTD_allocFunction customAlloc; ZSTD_freeFunction customFree; void* opaque; } ZSTD_customMem; static __attribute__((__unused__)) + +#if defined(__clang__) && __clang_major__ >= 5 +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wzero-as-null-pointer-constant" +#endif ZSTD_customMem const ZSTD_defaultCMem = { NULL, NULL, NULL }; /*< this constant defers to stdlib's functions */ +#if defined(__clang__) && __clang_major__ >= 5 +#pragma clang diagnostic pop +#endif ZSTDLIB_STATIC_API ZSTD_CCtx* ZSTD_createCCtx_advanced(ZSTD_customMem customMem); ZSTDLIB_STATIC_API ZSTD_CStream* ZSTD_createCStream_advanced(ZSTD_customMem customMem); @@ -1649,22 +1926,45 @@ ZSTDLIB_STATIC_API size_t ZSTD_checkCParams(ZSTD_compressionParameters params); * This function never fails (wide contract) */ ZSTDLIB_STATIC_API ZSTD_compressionParameters ZSTD_adjustCParams(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize); +/*! ZSTD_CCtx_setCParams() : + * Set all parameters provided within @p cparams into the working @p cctx. + * Note : if modifying parameters during compression (MT mode only), + * note that changes to the .windowLog parameter will be ignored. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + * On failure, no parameters are updated. + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams); + +/*! ZSTD_CCtx_setFParams() : + * Set all parameters provided within @p fparams into the working @p cctx. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams); + +/*! ZSTD_CCtx_setParams() : + * Set all parameters provided within @p params into the working @p cctx. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params); + /*! ZSTD_compress_advanced() : * Note : this function is now DEPRECATED. * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_setParameter() and other parameter setters. * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_compress2") +ZSTDLIB_STATIC_API size_t ZSTD_compress_advanced(ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - const void* dict,size_t dictSize, - ZSTD_parameters params); + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + const void* dict,size_t dictSize, + ZSTD_parameters params); /*! ZSTD_compress_usingCDict_advanced() : * Note : this function is now DEPRECATED. * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_loadDictionary() and other parameter setters. * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_compress2 with ZSTD_CCtx_loadDictionary") +ZSTDLIB_STATIC_API size_t ZSTD_compress_usingCDict_advanced(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, @@ -1725,7 +2025,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * See the comments on that enum for an explanation of the feature. */ #define ZSTD_c_forceAttachDict ZSTD_c_experimentalParam4 -/* Controlled with ZSTD_paramSwitch_e enum. +/* Controlled with ZSTD_ParamSwitch_e enum. * Default is ZSTD_ps_auto. * Set to ZSTD_ps_disable to never compress literals. * Set to ZSTD_ps_enable to always compress literals. (Note: uncompressed literals @@ -1737,11 +2037,6 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo */ #define ZSTD_c_literalCompressionMode ZSTD_c_experimentalParam5 -/* Tries to fit compressed block size to be around targetCBlockSize. - * No target when targetCBlockSize == 0. - * There is no guarantee on compressed block size (default:0) */ -#define ZSTD_c_targetCBlockSize ZSTD_c_experimentalParam6 - /* User's best guess of source size. * Hint is not valid when srcSizeHint == 0. * There is no guarantee that hint is close to actual source size, @@ -1808,13 +2103,16 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * Experimental parameter. * Default is 0 == disabled. Set to 1 to enable. * - * Tells the compressor that the ZSTD_inBuffer will ALWAYS be the same - * between calls, except for the modifications that zstd makes to pos (the - * caller must not modify pos). This is checked by the compressor, and - * compression will fail if it ever changes. This means the only flush - * mode that makes sense is ZSTD_e_end, so zstd will error if ZSTD_e_end - * is not used. The data in the ZSTD_inBuffer in the range [src, src + pos) - * MUST not be modified during compression or you will get data corruption. + * Tells the compressor that input data presented with ZSTD_inBuffer + * will ALWAYS be the same between calls. + * Technically, the @src pointer must never be changed, + * and the @pos field can only be updated by zstd. + * However, it's possible to increase the @size field, + * allowing scenarios where more data can be appended after compressions starts. + * These conditions are checked by the compressor, + * and compression will fail if they are not respected. + * Also, data in the ZSTD_inBuffer within the range [src, src + pos) + * MUST not be modified during compression or it will result in data corruption. * * When this flag is enabled zstd won't allocate an input window buffer, * because the user guarantees it can reference the ZSTD_inBuffer until @@ -1822,18 +2120,15 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * large enough to fit a block (see ZSTD_c_stableOutBuffer). This will also * avoid the memcpy() from the input buffer to the input window buffer. * - * NOTE: ZSTD_compressStream2() will error if ZSTD_e_end is not used. - * That means this flag cannot be used with ZSTD_compressStream(). - * * NOTE: So long as the ZSTD_inBuffer always points to valid memory, using * this flag is ALWAYS memory safe, and will never access out-of-bounds - * memory. However, compression WILL fail if you violate the preconditions. + * memory. However, compression WILL fail if conditions are not respected. * - * WARNING: The data in the ZSTD_inBuffer in the range [dst, dst + pos) MUST - * not be modified during compression or you will get data corruption. This - * is because zstd needs to reference data in the ZSTD_inBuffer to find + * WARNING: The data in the ZSTD_inBuffer in the range [src, src + pos) MUST + * not be modified during compression or it will result in data corruption. + * This is because zstd needs to reference data in the ZSTD_inBuffer to find * matches. Normally zstd maintains its own window buffer for this purpose, - * but passing this flag tells zstd to use the user provided buffer. + * but passing this flag tells zstd to rely on user provided buffer instead. */ #define ZSTD_c_stableInBuffer ZSTD_c_experimentalParam9 @@ -1871,22 +2166,46 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo /* ZSTD_c_validateSequences * Default is 0 == disabled. Set to 1 to enable sequence validation. * - * For use with sequence compression API: ZSTD_compressSequences(). - * Designates whether or not we validate sequences provided to ZSTD_compressSequences() + * For use with sequence compression API: ZSTD_compressSequences*(). + * Designates whether or not provided sequences are validated within ZSTD_compressSequences*() * during function execution. * - * Without validation, providing a sequence that does not conform to the zstd spec will cause - * undefined behavior, and may produce a corrupted block. + * When Sequence validation is disabled (default), Sequences are compressed as-is, + * so they must correct, otherwise it would result in a corruption error. * - * With validation enabled, a if sequence is invalid (see doc/zstd_compression_format.md for + * Sequence validation adds some protection, by ensuring that all values respect boundary conditions. + * If a Sequence is detected invalid (see doc/zstd_compression_format.md for * specifics regarding offset/matchlength requirements) then the function will bail out and * return an error. - * */ #define ZSTD_c_validateSequences ZSTD_c_experimentalParam12 -/* ZSTD_c_useBlockSplitter - * Controlled with ZSTD_paramSwitch_e enum. +/* ZSTD_c_blockSplitterLevel + * note: this parameter only influences the first splitter stage, + * which is active before producing the sequences. + * ZSTD_c_splitAfterSequences controls the next splitter stage, + * which is active after sequence production. + * Note that both can be combined. + * Allowed values are between 0 and ZSTD_BLOCKSPLITTER_LEVEL_MAX included. + * 0 means "auto", which will select a value depending on current ZSTD_c_strategy. + * 1 means no splitting. + * Then, values from 2 to 6 are sorted in increasing cpu load order. + * + * Note that currently the first block is never split, + * to ensure expansion guarantees in presence of incompressible data. + */ +#define ZSTD_BLOCKSPLITTER_LEVEL_MAX 6 +#define ZSTD_c_blockSplitterLevel ZSTD_c_experimentalParam20 + +/* ZSTD_c_splitAfterSequences + * This is a stronger splitter algorithm, + * based on actual sequences previously produced by the selected parser. + * It's also slower, and as a consequence, mostly used for high compression levels. + * While the post-splitter does overlap with the pre-splitter, + * both can nonetheless be combined, + * notably with ZSTD_c_blockSplitterLevel at ZSTD_BLOCKSPLITTER_LEVEL_MAX, + * resulting in higher compression ratio than just one of them. + * * Default is ZSTD_ps_auto. * Set to ZSTD_ps_disable to never use block splitter. * Set to ZSTD_ps_enable to always use block splitter. @@ -1894,10 +2213,10 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * By default, in ZSTD_ps_auto, the library will decide at runtime whether to use * block splitting based on the compression parameters. */ -#define ZSTD_c_useBlockSplitter ZSTD_c_experimentalParam13 +#define ZSTD_c_splitAfterSequences ZSTD_c_experimentalParam13 /* ZSTD_c_useRowMatchFinder - * Controlled with ZSTD_paramSwitch_e enum. + * Controlled with ZSTD_ParamSwitch_e enum. * Default is ZSTD_ps_auto. * Set to ZSTD_ps_disable to never use row-based matchfinder. * Set to ZSTD_ps_enable to force usage of row-based matchfinder. @@ -1928,6 +2247,80 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo */ #define ZSTD_c_deterministicRefPrefix ZSTD_c_experimentalParam15 +/* ZSTD_c_prefetchCDictTables + * Controlled with ZSTD_ParamSwitch_e enum. Default is ZSTD_ps_auto. + * + * In some situations, zstd uses CDict tables in-place rather than copying them + * into the working context. (See docs on ZSTD_dictAttachPref_e above for details). + * In such situations, compression speed is seriously impacted when CDict tables are + * "cold" (outside CPU cache). This parameter instructs zstd to prefetch CDict tables + * when they are used in-place. + * + * For sufficiently small inputs, the cost of the prefetch will outweigh the benefit. + * For sufficiently large inputs, zstd will by default memcpy() CDict tables + * into the working context, so there is no need to prefetch. This parameter is + * targeted at a middle range of input sizes, where a prefetch is cheap enough to be + * useful but memcpy() is too expensive. The exact range of input sizes where this + * makes sense is best determined by careful experimentation. + * + * Note: for this parameter, ZSTD_ps_auto is currently equivalent to ZSTD_ps_disable, + * but in the future zstd may conditionally enable this feature via an auto-detection + * heuristic for cold CDicts. + * Use ZSTD_ps_disable to opt out of prefetching under any circumstances. + */ +#define ZSTD_c_prefetchCDictTables ZSTD_c_experimentalParam16 + +/* ZSTD_c_enableSeqProducerFallback + * Allowed values are 0 (disable) and 1 (enable). The default setting is 0. + * + * Controls whether zstd will fall back to an internal sequence producer if an + * external sequence producer is registered and returns an error code. This fallback + * is block-by-block: the internal sequence producer will only be called for blocks + * where the external sequence producer returns an error code. Fallback parsing will + * follow any other cParam settings, such as compression level, the same as in a + * normal (fully-internal) compression operation. + * + * The user is strongly encouraged to read the full Block-Level Sequence Producer API + * documentation (below) before setting this parameter. */ +#define ZSTD_c_enableSeqProducerFallback ZSTD_c_experimentalParam17 + +/* ZSTD_c_maxBlockSize + * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). + * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. + * + * This parameter can be used to set an upper bound on the blocksize + * that overrides the default ZSTD_BLOCKSIZE_MAX. It cannot be used to set upper + * bounds greater than ZSTD_BLOCKSIZE_MAX or bounds lower than 1KB (will make + * compressBound() inaccurate). Only currently meant to be used for testing. + */ +#define ZSTD_c_maxBlockSize ZSTD_c_experimentalParam18 + +/* ZSTD_c_repcodeResolution + * This parameter only has an effect if ZSTD_c_blockDelimiters is + * set to ZSTD_sf_explicitBlockDelimiters (may change in the future). + * + * This parameter affects how zstd parses external sequences, + * provided via the ZSTD_compressSequences*() API + * or from an external block-level sequence producer. + * + * If set to ZSTD_ps_enable, the library will check for repeated offsets within + * external sequences, even if those repcodes are not explicitly indicated in + * the "rep" field. Note that this is the only way to exploit repcode matches + * while using compressSequences*() or an external sequence producer, since zstd + * currently ignores the "rep" field of external sequences. + * + * If set to ZSTD_ps_disable, the library will not exploit repeated offsets in + * external sequences, regardless of whether the "rep" field has been set. This + * reduces sequence compression overhead by about 25% while sacrificing some + * compression ratio. + * + * The default value is ZSTD_ps_auto, for which the library will enable/disable + * based on compression level (currently: level<10 disables, level>=10 enables). + */ +#define ZSTD_c_repcodeResolution ZSTD_c_experimentalParam19 +#define ZSTD_c_searchForExternalRepcodes ZSTD_c_experimentalParam19 /* older name */ + + /*! ZSTD_CCtx_getParameter() : * Get the requested compression parameter value, selected by enum ZSTD_cParameter, * and store it into int* value. @@ -2084,7 +2477,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete * in the range [dst, dst + pos) MUST not be modified during decompression * or you will get data corruption. * - * When this flags is enabled zstd won't allocate an output buffer, because + * When this flag is enabled zstd won't allocate an output buffer, because * it can write directly to the ZSTD_outBuffer, but it will still allocate * an input buffer large enough to fit any compressed block. This will also * avoid the memcpy() from the internal output buffer to the ZSTD_outBuffer. @@ -2137,6 +2530,33 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete */ #define ZSTD_d_refMultipleDDicts ZSTD_d_experimentalParam4 +/* ZSTD_d_disableHuffmanAssembly + * Set to 1 to disable the Huffman assembly implementation. + * The default value is 0, which allows zstd to use the Huffman assembly + * implementation if available. + * + * This parameter can be used to disable Huffman assembly at runtime. + * If you want to disable it at compile time you can define the macro + * ZSTD_DISABLE_ASM. + */ +#define ZSTD_d_disableHuffmanAssembly ZSTD_d_experimentalParam5 + +/* ZSTD_d_maxBlockSize + * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). + * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. + * + * Forces the decompressor to reject blocks whose content size is + * larger than the configured maxBlockSize. When maxBlockSize is + * larger than the windowSize, the windowSize is used instead. + * This saves memory on the decoder when you know all blocks are small. + * + * This option is typically used in conjunction with ZSTD_c_maxBlockSize. + * + * WARNING: This causes the decoder to reject otherwise valid frames + * that have block sizes larger than the configured maxBlockSize. + */ +#define ZSTD_d_maxBlockSize ZSTD_d_experimentalParam6 + /*! ZSTD_DCtx_setFormat() : * This function is REDUNDANT. Prefer ZSTD_DCtx_setParameter(). @@ -2145,6 +2565,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete * such ZSTD_f_zstd1_magicless for example. * @return : 0, or an error code (which can be tested using ZSTD_isError()). */ ZSTD_DEPRECATED("use ZSTD_DCtx_setParameter() instead") +ZSTDLIB_STATIC_API size_t ZSTD_DCtx_setFormat(ZSTD_DCtx* dctx, ZSTD_format_e format); /*! ZSTD_decompressStream_simpleArgs() : @@ -2181,6 +2602,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_decompressStream_simpleArgs ( * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, int compressionLevel, unsigned long long pledgedSrcSize); @@ -2198,17 +2620,15 @@ size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, const void* dict, size_t dictSize, int compressionLevel); /*! ZSTD_initCStream_advanced() : - * This function is DEPRECATED, and is approximately equivalent to: + * This function is DEPRECATED, and is equivalent to: * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); - * // Pseudocode: Set each zstd parameter and leave the rest as-is. - * for ((param, value) : params) { - * ZSTD_CCtx_setParameter(zcs, param, value); - * } + * ZSTD_CCtx_setParams(zcs, params); * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); * ZSTD_CCtx_loadDictionary(zcs, dict, dictSize); * @@ -2218,6 +2638,7 @@ size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, const void* dict, size_t dictSize, ZSTD_parameters params, @@ -2232,15 +2653,13 @@ size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); /*! ZSTD_initCStream_usingCDict_advanced() : - * This function is DEPRECATED, and is approximately equivalent to: + * This function is DEPRECATED, and is equivalent to: * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); - * // Pseudocode: Set each zstd frame parameter and leave the rest as-is. - * for ((fParam, value) : fParams) { - * ZSTD_CCtx_setParameter(zcs, fParam, value); - * } + * ZSTD_CCtx_setFParams(zcs, fParams); * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); * ZSTD_CCtx_refCDict(zcs, cdict); * @@ -2250,6 +2669,7 @@ size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, const ZSTD_CDict* cdict, ZSTD_frameParameters fParams, @@ -2264,7 +2684,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, * explicitly specified. * * start a new frame, using same parameters from previous frame. - * This is typically useful to skip dictionary loading stage, since it will re-use it in-place. + * This is typically useful to skip dictionary loading stage, since it will reuse it in-place. * Note that zcs must be init at least once before using ZSTD_resetCStream(). * If pledgedSrcSize is not known at reset time, use macro ZSTD_CONTENTSIZE_UNKNOWN. * If pledgedSrcSize > 0, its value must be correct, as it will be written in header, and controlled at the end. @@ -2274,6 +2694,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_resetCStream(ZSTD_CStream* zcs, unsigned long long pledgedSrcSize); @@ -2319,8 +2740,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_toFlushNow(ZSTD_CCtx* cctx); * ZSTD_DCtx_loadDictionary(zds, dict, dictSize); * * note: no dictionary will be used if dict == NULL or dictSize < 8 - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_loadDictionary, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t dictSize); /*! @@ -2330,8 +2751,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const vo * ZSTD_DCtx_refDDict(zds, ddict); * * note : ddict is referenced, it must outlive decompression session - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_refDDict, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const ZSTD_DDict* ddict); /*! @@ -2339,18 +2760,202 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const Z * * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); * - * re-use decompression parameters from previous init; saves dictionary loading - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x + * reuse decompression parameters from previous init; saves dictionary loading */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); +/* ********************* BLOCK-LEVEL SEQUENCE PRODUCER API ********************* + * + * *** OVERVIEW *** + * The Block-Level Sequence Producer API allows users to provide their own custom + * sequence producer which libzstd invokes to process each block. The produced list + * of sequences (literals and matches) is then post-processed by libzstd to produce + * valid compressed blocks. + * + * This block-level offload API is a more granular complement of the existing + * frame-level offload API compressSequences() (introduced in v1.5.1). It offers + * an easier migration story for applications already integrated with libzstd: the + * user application continues to invoke the same compression functions + * ZSTD_compress2() or ZSTD_compressStream2() as usual, and transparently benefits + * from the specific advantages of the external sequence producer. For example, + * the sequence producer could be tuned to take advantage of known characteristics + * of the input, to offer better speed / ratio, or could leverage hardware + * acceleration not available within libzstd itself. + * + * See contrib/externalSequenceProducer for an example program employing the + * Block-Level Sequence Producer API. + * + * *** USAGE *** + * The user is responsible for implementing a function of type + * ZSTD_sequenceProducer_F. For each block, zstd will pass the following + * arguments to the user-provided function: + * + * - sequenceProducerState: a pointer to a user-managed state for the sequence + * producer. + * + * - outSeqs, outSeqsCapacity: an output buffer for the sequence producer. + * outSeqsCapacity is guaranteed >= ZSTD_sequenceBound(srcSize). The memory + * backing outSeqs is managed by the CCtx. + * + * - src, srcSize: an input buffer for the sequence producer to parse. + * srcSize is guaranteed to be <= ZSTD_BLOCKSIZE_MAX. + * + * - dict, dictSize: a history buffer, which may be empty, which the sequence + * producer may reference as it parses the src buffer. Currently, zstd will + * always pass dictSize == 0 into external sequence producers, but this will + * change in the future. + * + * - compressionLevel: a signed integer representing the zstd compression level + * set by the user for the current operation. The sequence producer may choose + * to use this information to change its compression strategy and speed/ratio + * tradeoff. Note: the compression level does not reflect zstd parameters set + * through the advanced API. + * + * - windowSize: a size_t representing the maximum allowed offset for external + * sequences. Note that sequence offsets are sometimes allowed to exceed the + * windowSize if a dictionary is present, see doc/zstd_compression_format.md + * for details. + * + * The user-provided function shall return a size_t representing the number of + * sequences written to outSeqs. This return value will be treated as an error + * code if it is greater than outSeqsCapacity. The return value must be non-zero + * if srcSize is non-zero. The ZSTD_SEQUENCE_PRODUCER_ERROR macro is provided + * for convenience, but any value greater than outSeqsCapacity will be treated as + * an error code. + * + * If the user-provided function does not return an error code, the sequences + * written to outSeqs must be a valid parse of the src buffer. Data corruption may + * occur if the parse is not valid. A parse is defined to be valid if the + * following conditions hold: + * - The sum of matchLengths and literalLengths must equal srcSize. + * - All sequences in the parse, except for the final sequence, must have + * matchLength >= ZSTD_MINMATCH_MIN. The final sequence must have + * matchLength >= ZSTD_MINMATCH_MIN or matchLength == 0. + * - All offsets must respect the windowSize parameter as specified in + * doc/zstd_compression_format.md. + * - If the final sequence has matchLength == 0, it must also have offset == 0. + * + * zstd will only validate these conditions (and fail compression if they do not + * hold) if the ZSTD_c_validateSequences cParam is enabled. Note that sequence + * validation has a performance cost. + * + * If the user-provided function returns an error, zstd will either fall back + * to an internal sequence producer or fail the compression operation. The user can + * choose between the two behaviors by setting the ZSTD_c_enableSeqProducerFallback + * cParam. Fallback compression will follow any other cParam settings, such as + * compression level, the same as in a normal compression operation. + * + * The user shall instruct zstd to use a particular ZSTD_sequenceProducer_F + * function by calling + * ZSTD_registerSequenceProducer(cctx, + * sequenceProducerState, + * sequenceProducer) + * This setting will persist until the next parameter reset of the CCtx. + * + * The sequenceProducerState must be initialized by the user before calling + * ZSTD_registerSequenceProducer(). The user is responsible for destroying the + * sequenceProducerState. + * + * *** LIMITATIONS *** + * This API is compatible with all zstd compression APIs which respect advanced parameters. + * However, there are three limitations: + * + * First, the ZSTD_c_enableLongDistanceMatching cParam is not currently supported. + * COMPRESSION WILL FAIL if it is enabled and the user tries to compress with a block-level + * external sequence producer. + * - Note that ZSTD_c_enableLongDistanceMatching is auto-enabled by default in some + * cases (see its documentation for details). Users must explicitly set + * ZSTD_c_enableLongDistanceMatching to ZSTD_ps_disable in such cases if an external + * sequence producer is registered. + * - As of this writing, ZSTD_c_enableLongDistanceMatching is disabled by default + * whenever ZSTD_c_windowLog < 128MB, but that's subject to change. Users should + * check the docs on ZSTD_c_enableLongDistanceMatching whenever the Block-Level Sequence + * Producer API is used in conjunction with advanced settings (like ZSTD_c_windowLog). + * + * Second, history buffers are not currently supported. Concretely, zstd will always pass + * dictSize == 0 to the external sequence producer (for now). This has two implications: + * - Dictionaries are not currently supported. Compression will *not* fail if the user + * references a dictionary, but the dictionary won't have any effect. + * - Stream history is not currently supported. All advanced compression APIs, including + * streaming APIs, work with external sequence producers, but each block is treated as + * an independent chunk without history from previous blocks. + * + * Third, multi-threading within a single compression is not currently supported. In other words, + * COMPRESSION WILL FAIL if ZSTD_c_nbWorkers > 0 and an external sequence producer is registered. + * Multi-threading across compressions is fine: simply create one CCtx per thread. + * + * Long-term, we plan to overcome all three limitations. There is no technical blocker to + * overcoming them. It is purely a question of engineering effort. + */ + +#define ZSTD_SEQUENCE_PRODUCER_ERROR ((size_t)(-1)) + +typedef size_t (*ZSTD_sequenceProducer_F) ( + void* sequenceProducerState, + ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, + const void* src, size_t srcSize, + const void* dict, size_t dictSize, + int compressionLevel, + size_t windowSize +); + +/*! ZSTD_registerSequenceProducer() : + * Instruct zstd to use a block-level external sequence producer function. + * + * The sequenceProducerState must be initialized by the caller, and the caller is + * responsible for managing its lifetime. This parameter is sticky across + * compressions. It will remain set until the user explicitly resets compression + * parameters. + * + * Sequence producer registration is considered to be an "advanced parameter", + * part of the "advanced API". This means it will only have an effect on compression + * APIs which respect advanced parameters, such as compress2() and compressStream2(). + * Older compression APIs such as compressCCtx(), which predate the introduction of + * "advanced parameters", will ignore any external sequence producer setting. + * + * The sequence producer can be "cleared" by registering a NULL function pointer. This + * removes all limitations described above in the "LIMITATIONS" section of the API docs. + * + * The user is strongly encouraged to read the full API documentation (above) before + * calling this function. */ +ZSTDLIB_STATIC_API void +ZSTD_registerSequenceProducer( + ZSTD_CCtx* cctx, + void* sequenceProducerState, + ZSTD_sequenceProducer_F sequenceProducer +); + +/*! ZSTD_CCtxParams_registerSequenceProducer() : + * Same as ZSTD_registerSequenceProducer(), but operates on ZSTD_CCtx_params. + * This is used for accurate size estimation with ZSTD_estimateCCtxSize_usingCCtxParams(), + * which is needed when creating a ZSTD_CCtx with ZSTD_initStaticCCtx(). + * + * If you are using the external sequence producer API in a scenario where ZSTD_initStaticCCtx() + * is required, then this function is for you. Otherwise, you probably don't need it. + * + * See tests/zstreamtest.c for example usage. */ +ZSTDLIB_STATIC_API void +ZSTD_CCtxParams_registerSequenceProducer( + ZSTD_CCtx_params* params, + void* sequenceProducerState, + ZSTD_sequenceProducer_F sequenceProducer +); + + /* ******************************************************************* -* Buffer-less and synchronous inner streaming functions +* Buffer-less and synchronous inner streaming functions (DEPRECATED) +* +* This API is deprecated, and will be removed in a future version. +* It allows streaming (de)compression with user allocated buffers. +* However, it is hard to use, and not as well tested as the rest of +* our API. * -* This is an advanced API, giving full control over buffer management, for users which need direct control over memory. -* But it's also a complex one, with several restrictions, documented below. -* Prefer normal streaming API for an easier experience. +* Please use the normal streaming API instead: ZSTD_compressStream2, +* and ZSTD_decompressStream. +* If there is functionality that you need, but it doesn't provide, +* please open an issue on our GitHub. ********************************************************************* */ /* @@ -2358,11 +2963,10 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); A ZSTD_CCtx object is required to track streaming operations. Use ZSTD_createCCtx() / ZSTD_freeCCtx() to manage resource. - ZSTD_CCtx object can be re-used multiple times within successive compression operations. + ZSTD_CCtx object can be reused multiple times within successive compression operations. Start by initializing a context. Use ZSTD_compressBegin(), or ZSTD_compressBegin_usingDict() for dictionary compression. - It's also possible to duplicate a reference context which has already been initialized, using ZSTD_copyCCtx() Then, consume your input using ZSTD_compressContinue(). There are some important considerations to keep in mind when using this advanced function : @@ -2380,39 +2984,49 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); It's possible to use srcSize==0, in which case, it will write a final empty block to end the frame. Without last block mark, frames are considered unfinished (hence corrupted) by compliant decoders. - `ZSTD_CCtx` object can be re-used (ZSTD_compressBegin()) to compress again. + `ZSTD_CCtx` object can be reused (ZSTD_compressBegin()) to compress again. */ /*===== Buffer-less streaming compression functions =====*/ +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); /*< note: fails if cdict==NULL */ -ZSTDLIB_STATIC_API size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ +ZSTD_DEPRECATED("This function will likely be removed in a future release. It is misleading and has very limited utility.") +ZSTDLIB_STATIC_API +size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ + +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); /* The ZSTD_compressBegin_advanced() and ZSTD_compressBegin_usingCDict_advanced() are now DEPRECATED and will generate a compiler warning */ ZSTD_DEPRECATED("use advanced API to access custom parameters") +ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize); /*< pledgedSrcSize : If srcSize is not known at init time, use ZSTD_CONTENTSIZE_UNKNOWN */ ZSTD_DEPRECATED("use advanced API to access custom parameters") +ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_CDict* const cdict, ZSTD_frameParameters const fParams, unsigned long long const pledgedSrcSize); /* compression parameters are already set within cdict. pledgedSrcSize must be correct. If srcSize is not known, use macro ZSTD_CONTENTSIZE_UNKNOWN */ /* Buffer-less streaming decompression (synchronous mode) A ZSTD_DCtx object is required to track streaming operations. Use ZSTD_createDCtx() / ZSTD_freeDCtx() to manage it. - A ZSTD_DCtx object can be re-used multiple times. + A ZSTD_DCtx object can be reused multiple times. First typical operation is to retrieve frame parameters, using ZSTD_getFrameHeader(). Frame header is extracted from the beginning of compressed frame, so providing only the frame's beginning is enough. Data fragment must be large enough to ensure successful decoding. `ZSTD_frameHeaderSize_max` bytes is guaranteed to always be large enough. - @result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. - >0 : `srcSize` is too small, please provide at least @result bytes on next attempt. + result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. + >0 : `srcSize` is too small, please provide at least result bytes on next attempt. errorCode, which can be tested using ZSTD_isError(). - It fills a ZSTD_frameHeader structure with important information to correctly decode the frame, + It fills a ZSTD_FrameHeader structure with important information to correctly decode the frame, such as the dictionary ID, content size, or maximum back-reference distance (`windowSize`). Note that these values could be wrong, either because of data corruption, or because a 3rd party deliberately spoofs false information. As a consequence, check that values remain within valid application range. @@ -2428,7 +3042,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ The most memory efficient way is to use a round buffer of sufficient size. Sufficient size is determined by invoking ZSTD_decodingBufferSize_min(), - which can @return an error code if required value is too large for current system (in 32-bits mode). + which can return an error code if required value is too large for current system (in 32-bits mode). In a round buffer methodology, ZSTD_decompressContinue() decompresses each block next to previous one, up to the moment there is not enough room left in the buffer to guarantee decoding another full block, which maximum size is provided in `ZSTD_frameHeader` structure, field `blockSizeMax`. @@ -2448,7 +3062,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ ZSTD_nextSrcSizeToDecompress() tells how many bytes to provide as 'srcSize' to ZSTD_decompressContinue(). ZSTD_decompressContinue() requires this _exact_ amount of bytes, or it will fail. - @result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). + result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). It can be zero : it just means ZSTD_decompressContinue() has decoded some metadata item. It can also be an error code, which can be tested with ZSTD_isError(). @@ -2471,27 +3085,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ */ /*===== Buffer-less streaming decompression functions =====*/ -typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_frameType_e; -typedef struct { - unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ - unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ - unsigned blockSizeMax; - ZSTD_frameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ - unsigned headerSize; - unsigned dictID; - unsigned checksumFlag; -} ZSTD_frameHeader; -/*! ZSTD_getFrameHeader() : - * decode Frame Header, or requires larger `srcSize`. - * @return : 0, `zfhPtr` is correctly filled, - * >0, `srcSize` is too small, value is wanted `srcSize` amount, - * or an error code, which can be tested using ZSTD_isError() */ -ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize); /*< doesn't consume input */ -/*! ZSTD_getFrameHeader_advanced() : - * same as ZSTD_getFrameHeader(), - * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ -ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); ZSTDLIB_STATIC_API size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize); /*< when frame content size is not known, pass in frameContentSize == ZSTD_CONTENTSIZE_UNKNOWN */ ZSTDLIB_STATIC_API size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx); @@ -2502,6 +3096,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx); ZSTDLIB_STATIC_API size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); /* misc */ +ZSTD_DEPRECATED("This function will likely be removed in the next minor release. It is misleading and has very limited utility.") ZSTDLIB_STATIC_API void ZSTD_copyDCtx(ZSTD_DCtx* dctx, const ZSTD_DCtx* preparedDCtx); typedef enum { ZSTDnit_frameHeader, ZSTDnit_blockHeader, ZSTDnit_block, ZSTDnit_lastBlock, ZSTDnit_checksum, ZSTDnit_skippableFrame } ZSTD_nextInputType_e; ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); @@ -2509,11 +3104,23 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); -/* ============================ */ -/* Block level API */ -/* ============================ */ +/* ========================================= */ +/* Block level API (DEPRECATED) */ +/* ========================================= */ /*! + + This API is deprecated in favor of the regular compression API. + You can get the frame header down to 2 bytes by setting: + - ZSTD_c_format = ZSTD_f_zstd1_magicless + - ZSTD_c_contentSizeFlag = 0 + - ZSTD_c_checksumFlag = 0 + - ZSTD_c_dictIDFlag = 0 + + This API is not as well tested as our normal API, so we recommend not using it. + We will be removing it in a future version. If the normal API doesn't provide + the functionality you need, please open a GitHub issue. + Block functions produce and decode raw zstd blocks, without frame metadata. Frame metadata cost is typically ~12 bytes, which can be non-negligible for very small blocks (< 100 bytes). But users will have to take in charge needed metadata to regenerate data, such as compressed and content sizes. @@ -2524,7 +3131,6 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); - It is necessary to init context before starting + compression : any ZSTD_compressBegin*() variant, including with dictionary + decompression : any ZSTD_decompressBegin*() variant, including with dictionary - + copyCCtx() and copyDCtx() can be used too - Block size is limited, it must be <= ZSTD_getBlockSize() <= ZSTD_BLOCKSIZE_MAX == 128 KB + If input is larger than a block size, it's necessary to split input data into multiple blocks + For inputs larger than a single block, consider using regular ZSTD_compress() instead. @@ -2541,11 +3147,14 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); */ /*===== Raw zstd block functions =====*/ +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_getBlockSize (const ZSTD_CCtx* cctx); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBlock (ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_insertBlock (ZSTD_DCtx* dctx, const void* blockStart, size_t blockSize); /*< insert uncompressed block into `dctx` history. Useful for multi-blocks decompression. */ #endif /* ZSTD_H_ZSTD_STATIC_LINKING_ONLY */ - diff --git a/lib/zstd/Makefile b/lib/zstd/Makefile index 20f08c644b71..be218b5e0ed5 100644 --- a/lib/zstd/Makefile +++ b/lib/zstd/Makefile @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause # ################################################################ -# Copyright (c) Facebook, Inc. +# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under both the BSD-style license (found in the @@ -26,6 +26,7 @@ zstd_compress-y := \ compress/zstd_lazy.o \ compress/zstd_ldm.o \ compress/zstd_opt.o \ + compress/zstd_preSplit.o \ zstd_decompress-y := \ zstd_decompress_module.o \ diff --git a/lib/zstd/common/allocations.h b/lib/zstd/common/allocations.h new file mode 100644 index 000000000000..16c3d08e8d1a --- /dev/null +++ b/lib/zstd/common/allocations.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +/* This file provides custom allocation primitives + */ + +#define ZSTD_DEPS_NEED_MALLOC +#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ + +#include "compiler.h" /* MEM_STATIC */ +#define ZSTD_STATIC_LINKING_ONLY +#include /* ZSTD_customMem */ + +#ifndef ZSTD_ALLOCATIONS_H +#define ZSTD_ALLOCATIONS_H + +/* custom memory allocation functions */ + +MEM_STATIC void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) +{ + if (customMem.customAlloc) + return customMem.customAlloc(customMem.opaque, size); + return ZSTD_malloc(size); +} + +MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) +{ + if (customMem.customAlloc) { + /* calloc implemented as malloc+memset; + * not as efficient as calloc, but next best guess for custom malloc */ + void* const ptr = customMem.customAlloc(customMem.opaque, size); + ZSTD_memset(ptr, 0, size); + return ptr; + } + return ZSTD_calloc(1, size); +} + +MEM_STATIC void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) +{ + if (ptr!=NULL) { + if (customMem.customFree) + customMem.customFree(customMem.opaque, ptr); + else + ZSTD_free(ptr); + } +} + +#endif /* ZSTD_ALLOCATIONS_H */ diff --git a/lib/zstd/common/bits.h b/lib/zstd/common/bits.h new file mode 100644 index 000000000000..c5faaa3d7b08 --- /dev/null +++ b/lib/zstd/common/bits.h @@ -0,0 +1,150 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#ifndef ZSTD_BITS_H +#define ZSTD_BITS_H + +#include "mem.h" + +MEM_STATIC unsigned ZSTD_countTrailingZeros32_fallback(U32 val) +{ + assert(val != 0); + { + static const U32 DeBruijnBytePos[32] = {0, 1, 28, 2, 29, 14, 24, 3, + 30, 22, 20, 15, 25, 17, 4, 8, + 31, 27, 13, 23, 21, 19, 16, 7, + 26, 12, 18, 6, 11, 5, 10, 9}; + return DeBruijnBytePos[((U32) ((val & -(S32) val) * 0x077CB531U)) >> 27]; + } +} + +MEM_STATIC unsigned ZSTD_countTrailingZeros32(U32 val) +{ + assert(val != 0); +#if (__GNUC__ >= 4) + return (unsigned)__builtin_ctz(val); +#else + return ZSTD_countTrailingZeros32_fallback(val); +#endif +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros32_fallback(U32 val) +{ + assert(val != 0); + { + static const U32 DeBruijnClz[32] = {0, 9, 1, 10, 13, 21, 2, 29, + 11, 14, 16, 18, 22, 25, 3, 30, + 8, 12, 20, 28, 15, 17, 24, 7, + 19, 27, 23, 6, 26, 5, 4, 31}; + val |= val >> 1; + val |= val >> 2; + val |= val >> 4; + val |= val >> 8; + val |= val >> 16; + return 31 - DeBruijnClz[(val * 0x07C4ACDDU) >> 27]; + } +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros32(U32 val) +{ + assert(val != 0); +#if (__GNUC__ >= 4) + return (unsigned)__builtin_clz(val); +#else + return ZSTD_countLeadingZeros32_fallback(val); +#endif +} + +MEM_STATIC unsigned ZSTD_countTrailingZeros64(U64 val) +{ + assert(val != 0); +#if (__GNUC__ >= 4) && defined(__LP64__) + return (unsigned)__builtin_ctzll(val); +#else + { + U32 mostSignificantWord = (U32)(val >> 32); + U32 leastSignificantWord = (U32)val; + if (leastSignificantWord == 0) { + return 32 + ZSTD_countTrailingZeros32(mostSignificantWord); + } else { + return ZSTD_countTrailingZeros32(leastSignificantWord); + } + } +#endif +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros64(U64 val) +{ + assert(val != 0); +#if (__GNUC__ >= 4) + return (unsigned)(__builtin_clzll(val)); +#else + { + U32 mostSignificantWord = (U32)(val >> 32); + U32 leastSignificantWord = (U32)val; + if (mostSignificantWord == 0) { + return 32 + ZSTD_countLeadingZeros32(leastSignificantWord); + } else { + return ZSTD_countLeadingZeros32(mostSignificantWord); + } + } +#endif +} + +MEM_STATIC unsigned ZSTD_NbCommonBytes(size_t val) +{ + if (MEM_isLittleEndian()) { + if (MEM_64bits()) { + return ZSTD_countTrailingZeros64((U64)val) >> 3; + } else { + return ZSTD_countTrailingZeros32((U32)val) >> 3; + } + } else { /* Big Endian CPU */ + if (MEM_64bits()) { + return ZSTD_countLeadingZeros64((U64)val) >> 3; + } else { + return ZSTD_countLeadingZeros32((U32)val) >> 3; + } + } +} + +MEM_STATIC unsigned ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ +{ + assert(val != 0); + return 31 - ZSTD_countLeadingZeros32(val); +} + +/* ZSTD_rotateRight_*(): + * Rotates a bitfield to the right by "count" bits. + * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts + */ +MEM_STATIC +U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { + assert(count < 64); + count &= 0x3F; /* for fickle pattern recognition */ + return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); +} + +MEM_STATIC +U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { + assert(count < 32); + count &= 0x1F; /* for fickle pattern recognition */ + return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); +} + +MEM_STATIC +U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { + assert(count < 16); + count &= 0x0F; /* for fickle pattern recognition */ + return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); +} + +#endif /* ZSTD_BITS_H */ diff --git a/lib/zstd/common/bitstream.h b/lib/zstd/common/bitstream.h index feef3a1b1d60..86439da0eea7 100644 --- a/lib/zstd/common/bitstream.h +++ b/lib/zstd/common/bitstream.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * bitstream * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -27,7 +28,7 @@ #include "compiler.h" /* UNLIKELY() */ #include "debug.h" /* assert(), DEBUGLOG(), RAWLOG() */ #include "error_private.h" /* error codes and messages */ - +#include "bits.h" /* ZSTD_highbit32 */ /*========================================= * Target specific @@ -41,12 +42,13 @@ /*-****************************************** * bitStream encoding API (write forward) ********************************************/ +typedef size_t BitContainerType; /* bitStream can mix input from multiple sources. * A critical property of these streams is that they encode and decode in **reverse** direction. * So the first bit sequence you add will be the last to be read, like a LIFO stack. */ typedef struct { - size_t bitContainer; + BitContainerType bitContainer; unsigned bitPos; char* startPtr; char* ptr; @@ -54,7 +56,7 @@ typedef struct { } BIT_CStream_t; MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, void* dstBuffer, size_t dstCapacity); -MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, size_t value, unsigned nbBits); +MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, BitContainerType value, unsigned nbBits); MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC); MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC); @@ -63,7 +65,7 @@ MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC); * `dstCapacity` must be >= sizeof(bitD->bitContainer), otherwise @return will be an error code. * * bits are first added to a local register. -* Local register is size_t, hence 64-bits on 64-bits systems, or 32-bits on 32-bits systems. +* Local register is BitContainerType, 64-bits on 64-bits systems, or 32-bits on 32-bits systems. * Writing data into memory is an explicit operation, performed by the flushBits function. * Hence keep track how many bits are potentially stored into local register to avoid register overflow. * After a flushBits, a maximum of 7 bits might still be stored into local register. @@ -80,28 +82,28 @@ MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC); * bitStream decoding API (read backward) **********************************************/ typedef struct { - size_t bitContainer; + BitContainerType bitContainer; unsigned bitsConsumed; const char* ptr; const char* start; const char* limitPtr; } BIT_DStream_t; -typedef enum { BIT_DStream_unfinished = 0, - BIT_DStream_endOfBuffer = 1, - BIT_DStream_completed = 2, - BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */ - /* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */ +typedef enum { BIT_DStream_unfinished = 0, /* fully refilled */ + BIT_DStream_endOfBuffer = 1, /* still some bits left in bitstream */ + BIT_DStream_completed = 2, /* bitstream entirely consumed, bit-exact */ + BIT_DStream_overflow = 3 /* user requested more bits than present in bitstream */ + } BIT_DStream_status; /* result of BIT_reloadDStream() */ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize); -MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits); +MEM_STATIC BitContainerType BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits); MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD); MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD); /* Start by invoking BIT_initDStream(). * A chunk of the bitStream is then stored into a local register. -* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t). +* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (BitContainerType). * You can then retrieve bitFields stored into the local register, **in reverse order**. * Local register is explicitly reloaded from memory by the BIT_reloadDStream() method. * A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished. @@ -113,7 +115,7 @@ MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD); /*-**************************************** * unsafe API ******************************************/ -MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, size_t value, unsigned nbBits); +MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, BitContainerType value, unsigned nbBits); /* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */ MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC); @@ -122,33 +124,6 @@ MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC); MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits); /* faster, but works only if nbBits >= 1 */ - - -/*-************************************************************** -* Internal functions -****************************************************************/ -MEM_STATIC unsigned BIT_highbit32 (U32 val) -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* Use GCC Intrinsic */ - return __builtin_clz (val) ^ 31; -# else /* Software version */ - static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, - 11, 14, 16, 18, 22, 25, 3, 30, - 8, 12, 20, 28, 15, 17, 24, 7, - 19, 27, 23, 6, 26, 5, 4, 31 }; - U32 v = val; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27]; -# endif - } -} - /*===== Local Constants =====*/ static const unsigned BIT_mask[] = { 0, 1, 3, 7, 0xF, 0x1F, @@ -178,16 +153,22 @@ MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, return 0; } +FORCE_INLINE_TEMPLATE BitContainerType BIT_getLowerBits(BitContainerType bitContainer, U32 const nbBits) +{ + assert(nbBits < BIT_MASK_SIZE); + return bitContainer & BIT_mask[nbBits]; +} + /*! BIT_addBits() : * can add up to 31 bits into `bitC`. * Note : does not check for register overflow ! */ MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, - size_t value, unsigned nbBits) + BitContainerType value, unsigned nbBits) { DEBUG_STATIC_ASSERT(BIT_MASK_SIZE == 32); assert(nbBits < BIT_MASK_SIZE); assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8); - bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos; + bitC->bitContainer |= BIT_getLowerBits(value, nbBits) << bitC->bitPos; bitC->bitPos += nbBits; } @@ -195,7 +176,7 @@ MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, * works only if `value` is _clean_, * meaning all high bits above nbBits are 0 */ MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, - size_t value, unsigned nbBits) + BitContainerType value, unsigned nbBits) { assert((value>>nbBits) == 0); assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8); @@ -242,7 +223,7 @@ MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC) BIT_addBitsFast(bitC, 1, 1); /* endMark */ BIT_flushBits(bitC); if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */ - return (bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0); + return (size_t)(bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0); } @@ -266,35 +247,35 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer); bitD->bitContainer = MEM_readLEST(bitD->ptr); { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; - bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ + bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ } } else { bitD->ptr = bitD->start; bitD->bitContainer = *(const BYTE*)(bitD->start); switch(srcSize) { - case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); + case 7: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); ZSTD_FALLTHROUGH; - case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); + case 6: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); ZSTD_FALLTHROUGH; - case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); + case 5: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); ZSTD_FALLTHROUGH; - case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24; + case 4: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[3]) << 24; ZSTD_FALLTHROUGH; - case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16; + case 3: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[2]) << 16; ZSTD_FALLTHROUGH; - case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8; + case 2: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[1]) << 8; ZSTD_FALLTHROUGH; default: break; } { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; - bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; + bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */ } bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8; @@ -303,12 +284,12 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si return srcSize; } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getUpperBits(size_t bitContainer, U32 const start) +FORCE_INLINE_TEMPLATE BitContainerType BIT_getUpperBits(BitContainerType bitContainer, U32 const start) { return bitContainer >> start; } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits) +FORCE_INLINE_TEMPLATE BitContainerType BIT_getMiddleBits(BitContainerType bitContainer, U32 const start, U32 const nbBits) { U32 const regMask = sizeof(bitContainer)*8 - 1; /* if start > regMask, bitstream is corrupted, and result is undefined */ @@ -318,26 +299,20 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 c * such cpus old (pre-Haswell, 2013) and their performance is not of that * importance. */ -#if defined(__x86_64__) || defined(_M_X86) +#if defined(__x86_64__) || defined(_M_X64) return (bitContainer >> (start & regMask)) & ((((U64)1) << nbBits) - 1); #else return (bitContainer >> (start & regMask)) & BIT_mask[nbBits]; #endif } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) -{ - assert(nbBits < BIT_MASK_SIZE); - return bitContainer & BIT_mask[nbBits]; -} - /*! BIT_lookBits() : * Provides next n bits from local register. * local register is not modified. * On 32-bits, maxNbBits==24. * On 64-bits, maxNbBits==56. * @return : value extracted */ -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) +FORCE_INLINE_TEMPLATE BitContainerType BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) { /* arbitrate between double-shift and shift+mask */ #if 1 @@ -353,14 +328,14 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_lookBits(const BIT_DStream_t* bitD, U3 /*! BIT_lookBitsFast() : * unsafe version; only works if nbBits >= 1 */ -MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits) +MEM_STATIC BitContainerType BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits) { U32 const regMask = sizeof(bitD->bitContainer)*8 - 1; assert(nbBits >= 1); return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask); } -MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) +FORCE_INLINE_TEMPLATE void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) { bitD->bitsConsumed += nbBits; } @@ -369,23 +344,38 @@ MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) * Read (consume) next n bits from local register and update. * Pay attention to not read more than nbBits contained into local register. * @return : extracted value. */ -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) +FORCE_INLINE_TEMPLATE BitContainerType BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) { - size_t const value = BIT_lookBits(bitD, nbBits); + BitContainerType const value = BIT_lookBits(bitD, nbBits); BIT_skipBits(bitD, nbBits); return value; } /*! BIT_readBitsFast() : - * unsafe version; only works only if nbBits >= 1 */ -MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) + * unsafe version; only works if nbBits >= 1 */ +MEM_STATIC BitContainerType BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) { - size_t const value = BIT_lookBitsFast(bitD, nbBits); + BitContainerType const value = BIT_lookBitsFast(bitD, nbBits); assert(nbBits >= 1); BIT_skipBits(bitD, nbBits); return value; } +/*! BIT_reloadDStream_internal() : + * Simple variant of BIT_reloadDStream(), with two conditions: + * 1. bitstream is valid : bitsConsumed <= sizeof(bitD->bitContainer)*8 + * 2. look window is valid after shifted down : bitD->ptr >= bitD->start + */ +MEM_STATIC BIT_DStream_status BIT_reloadDStream_internal(BIT_DStream_t* bitD) +{ + assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); + bitD->ptr -= bitD->bitsConsumed >> 3; + assert(bitD->ptr >= bitD->start); + bitD->bitsConsumed &= 7; + bitD->bitContainer = MEM_readLEST(bitD->ptr); + return BIT_DStream_unfinished; +} + /*! BIT_reloadDStreamFast() : * Similar to BIT_reloadDStream(), but with two differences: * 1. bitsConsumed <= sizeof(bitD->bitContainer)*8 must hold! @@ -396,31 +386,35 @@ MEM_STATIC BIT_DStream_status BIT_reloadDStreamFast(BIT_DStream_t* bitD) { if (UNLIKELY(bitD->ptr < bitD->limitPtr)) return BIT_DStream_overflow; - assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); - bitD->ptr -= bitD->bitsConsumed >> 3; - bitD->bitsConsumed &= 7; - bitD->bitContainer = MEM_readLEST(bitD->ptr); - return BIT_DStream_unfinished; + return BIT_reloadDStream_internal(bitD); } /*! BIT_reloadDStream() : * Refill `bitD` from buffer previously set in BIT_initDStream() . - * This function is safe, it guarantees it will not read beyond src buffer. + * This function is safe, it guarantees it will not never beyond src buffer. * @return : status of `BIT_DStream_t` internal register. * when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */ -MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) +FORCE_INLINE_TEMPLATE BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) { - if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */ + /* note : once in overflow mode, a bitstream remains in this mode until it's reset */ + if (UNLIKELY(bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8))) { + static const BitContainerType zeroFilled = 0; + bitD->ptr = (const char*)&zeroFilled; /* aliasing is allowed for char */ + /* overflow detected, erroneous scenario or end of stream: no update */ return BIT_DStream_overflow; + } + + assert(bitD->ptr >= bitD->start); if (bitD->ptr >= bitD->limitPtr) { - return BIT_reloadDStreamFast(bitD); + return BIT_reloadDStream_internal(bitD); } if (bitD->ptr == bitD->start) { + /* reached end of bitStream => no update */ if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer; return BIT_DStream_completed; } - /* start < ptr < limitPtr */ + /* start < ptr < limitPtr => cautious update */ { U32 nbBytes = bitD->bitsConsumed >> 3; BIT_DStream_status result = BIT_DStream_unfinished; if (bitD->ptr - nbBytes < bitD->start) { @@ -442,5 +436,4 @@ MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* DStream) return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer)*8)); } - #endif /* BITSTREAM_H_MODULE */ diff --git a/lib/zstd/common/compiler.h b/lib/zstd/common/compiler.h index c42d39faf9bd..dc9bd15e174e 100644 --- a/lib/zstd/common/compiler.h +++ b/lib/zstd/common/compiler.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,6 +12,8 @@ #ifndef ZSTD_COMPILER_H #define ZSTD_COMPILER_H +#include + #include "portability_macros.h" /*-******************************************************* @@ -41,12 +44,15 @@ */ #define WIN_CDECL +/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ +#define UNUSED_ATTR __attribute__((unused)) + /* * FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant * parameters. They must be inlined for the compiler to eliminate the constant * branches. */ -#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR +#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR UNUSED_ATTR /* * HINT_INLINE is used to help the compiler generate better code. It is *not* * used for "templates", so it can be tweaked based on the compilers @@ -61,11 +67,21 @@ #if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5 # define HINT_INLINE static INLINE_KEYWORD #else -# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR +# define HINT_INLINE FORCE_INLINE_TEMPLATE #endif -/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ -#define UNUSED_ATTR __attribute__((unused)) +/* "soft" inline : + * The compiler is free to select if it's a good idea to inline or not. + * The main objective is to silence compiler warnings + * when a defined function in included but not used. + * + * Note : this macro is prefixed `MEM_` because it used to be provided by `mem.h` unit. + * Updating the prefix is probably preferable, but requires a fairly large codemod, + * since this name is used everywhere. + */ +#ifndef MEM_STATIC /* already defined in Linux Kernel mem.h */ +#define MEM_STATIC static __inline UNUSED_ATTR +#endif /* force no inlining */ #define FORCE_NOINLINE static __attribute__((__noinline__)) @@ -86,23 +102,24 @@ # define PREFETCH_L1(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) # define PREFETCH_L2(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 2 /* locality */) #elif defined(__aarch64__) -# define PREFETCH_L1(ptr) __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))) -# define PREFETCH_L2(ptr) __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))) +# define PREFETCH_L1(ptr) do { __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))); } while (0) +# define PREFETCH_L2(ptr) do { __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))); } while (0) #else -# define PREFETCH_L1(ptr) (void)(ptr) /* disabled */ -# define PREFETCH_L2(ptr) (void)(ptr) /* disabled */ +# define PREFETCH_L1(ptr) do { (void)(ptr); } while (0) /* disabled */ +# define PREFETCH_L2(ptr) do { (void)(ptr); } while (0) /* disabled */ #endif /* NO_PREFETCH */ #define CACHELINE_SIZE 64 -#define PREFETCH_AREA(p, s) { \ - const char* const _ptr = (const char*)(p); \ - size_t const _size = (size_t)(s); \ - size_t _pos; \ - for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ - PREFETCH_L2(_ptr + _pos); \ - } \ -} +#define PREFETCH_AREA(p, s) \ + do { \ + const char* const _ptr = (const char*)(p); \ + size_t const _size = (size_t)(s); \ + size_t _pos; \ + for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ + PREFETCH_L2(_ptr + _pos); \ + } \ + } while (0) /* vectorization * older GCC (pre gcc-4.3 picked as the cutoff) uses a different syntax, @@ -126,16 +143,13 @@ #define UNLIKELY(x) (__builtin_expect((x), 0)) #if __has_builtin(__builtin_unreachable) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5))) -# define ZSTD_UNREACHABLE { assert(0), __builtin_unreachable(); } +# define ZSTD_UNREACHABLE do { assert(0), __builtin_unreachable(); } while (0) #else -# define ZSTD_UNREACHABLE { assert(0); } +# define ZSTD_UNREACHABLE do { assert(0); } while (0) #endif /* disable warnings */ -/*Like DYNAMIC_BMI2 but for compile time determination of BMI2 support*/ - - /* compile time determination of SIMD support */ /* C-language Attributes are added in C23. */ @@ -158,9 +172,15 @@ #define ZSTD_FALLTHROUGH fallthrough /*-************************************************************** -* Alignment check +* Alignment *****************************************************************/ +/* @return 1 if @u is a 2^n value, 0 otherwise + * useful to check a value is valid for alignment restrictions */ +MEM_STATIC int ZSTD_isPower2(size_t u) { + return (u & (u-1)) == 0; +} + /* this test was initially positioned in mem.h, * but this file is removed (or replaced) for linux kernel * so it's now hosted in compiler.h, @@ -175,10 +195,95 @@ #endif /* ZSTD_ALIGNOF */ +#ifndef ZSTD_ALIGNED +/* C90-compatible alignment macro (GCC/Clang). Adjust for other compilers if needed. */ +#define ZSTD_ALIGNED(a) __attribute__((aligned(a))) +#endif /* ZSTD_ALIGNED */ + + /*-************************************************************** * Sanitizer *****************************************************************/ +/* + * Zstd relies on pointer overflow in its decompressor. + * We add this attribute to functions that rely on pointer overflow. + */ +#ifndef ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +# if __has_attribute(no_sanitize) +# if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 8 + /* gcc < 8 only has signed-integer-overlow which triggers on pointer overflow */ +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("signed-integer-overflow"))) +# else + /* older versions of clang [3.7, 5.0) will warn that pointer-overflow is ignored. */ +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("pointer-overflow"))) +# endif +# else +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +# endif +#endif + +/* + * Helper function to perform a wrapped pointer difference without triggering + * UBSAN. + * + * @returns lhs - rhs with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +ptrdiff_t ZSTD_wrappedPtrDiff(unsigned char const* lhs, unsigned char const* rhs) +{ + return lhs - rhs; +} + +/* + * Helper function to perform a wrapped pointer add without triggering UBSAN. + * + * @return ptr + add with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +unsigned char const* ZSTD_wrappedPtrAdd(unsigned char const* ptr, ptrdiff_t add) +{ + return ptr + add; +} + +/* + * Helper function to perform a wrapped pointer subtraction without triggering + * UBSAN. + * + * @return ptr - sub with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +unsigned char const* ZSTD_wrappedPtrSub(unsigned char const* ptr, ptrdiff_t sub) +{ + return ptr - sub; +} + +/* + * Helper function to add to a pointer that works around C's undefined behavior + * of adding 0 to NULL. + * + * @returns `ptr + add` except it defines `NULL + 0 == NULL`. + */ +MEM_STATIC +unsigned char* ZSTD_maybeNullPtrAdd(unsigned char* ptr, ptrdiff_t add) +{ + return add > 0 ? ptr + add : ptr; +} + +/* Issue #3240 reports an ASAN failure on an llvm-mingw build. Out of an + * abundance of caution, disable our custom poisoning on mingw. */ +#ifdef __MINGW32__ +#ifndef ZSTD_ASAN_DONT_POISON_WORKSPACE +#define ZSTD_ASAN_DONT_POISON_WORKSPACE 1 +#endif +#ifndef ZSTD_MSAN_DONT_POISON_WORKSPACE +#define ZSTD_MSAN_DONT_POISON_WORKSPACE 1 +#endif +#endif + #endif /* ZSTD_COMPILER_H */ diff --git a/lib/zstd/common/cpu.h b/lib/zstd/common/cpu.h index 0db7b42407ee..d8319a2bef4c 100644 --- a/lib/zstd/common/cpu.h +++ b/lib/zstd/common/cpu.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/common/debug.c b/lib/zstd/common/debug.c index bb863c9ea616..8eb6aa9a3b20 100644 --- a/lib/zstd/common/debug.c +++ b/lib/zstd/common/debug.c @@ -1,7 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * debug * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -21,4 +22,10 @@ #include "debug.h" +#if (DEBUGLEVEL>=2) +/* We only use this when DEBUGLEVEL>=2, but we get -Werror=pedantic errors if a + * translation unit is empty. So remove this from Linux kernel builds, but + * otherwise just leave it in. + */ int g_debuglevel = DEBUGLEVEL; +#endif diff --git a/lib/zstd/common/debug.h b/lib/zstd/common/debug.h index 6dd88d1fbd02..c8a10281f112 100644 --- a/lib/zstd/common/debug.h +++ b/lib/zstd/common/debug.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * debug * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -33,7 +34,6 @@ #define DEBUG_H_12987983217 - /* static assert is triggered at compile time, leaving no runtime artefact. * static assert only works with compile-time constants. * Also, this variant can only be used inside a function. */ @@ -82,20 +82,27 @@ extern int g_debuglevel; /* the variable is only declared, It's useful when enabling very verbose levels on selective conditions (such as position in src) */ -# define RAWLOG(l, ...) { \ - if (l<=g_debuglevel) { \ - ZSTD_DEBUG_PRINT(__VA_ARGS__); \ - } } -# define DEBUGLOG(l, ...) { \ - if (l<=g_debuglevel) { \ - ZSTD_DEBUG_PRINT(__FILE__ ": " __VA_ARGS__); \ - ZSTD_DEBUG_PRINT(" \n"); \ - } } +# define RAWLOG(l, ...) \ + do { \ + if (l<=g_debuglevel) { \ + ZSTD_DEBUG_PRINT(__VA_ARGS__); \ + } \ + } while (0) + +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) +#define LINE_AS_STRING TOSTRING(__LINE__) + +# define DEBUGLOG(l, ...) \ + do { \ + if (l<=g_debuglevel) { \ + ZSTD_DEBUG_PRINT(__FILE__ ":" LINE_AS_STRING ": " __VA_ARGS__); \ + ZSTD_DEBUG_PRINT(" \n"); \ + } \ + } while (0) #else -# define RAWLOG(l, ...) {} /* disabled */ -# define DEBUGLOG(l, ...) {} /* disabled */ +# define RAWLOG(l, ...) do { } while (0) /* disabled */ +# define DEBUGLOG(l, ...) do { } while (0) /* disabled */ #endif - - #endif /* DEBUG_H_12987983217 */ diff --git a/lib/zstd/common/entropy_common.c b/lib/zstd/common/entropy_common.c index fef67056f052..6cdd82233fb5 100644 --- a/lib/zstd/common/entropy_common.c +++ b/lib/zstd/common/entropy_common.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * Common functions of New Generation Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -19,8 +20,8 @@ #include "error_private.h" /* ERR_*, ERROR */ #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ #include "fse.h" -#define HUF_STATIC_LINKING_ONLY /* HUF_TABLELOG_ABSOLUTEMAX */ #include "huf.h" +#include "bits.h" /* ZSDT_highbit32, ZSTD_countTrailingZeros32 */ /*=== Version ===*/ @@ -38,23 +39,6 @@ const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); } /*-************************************************************** * FSE NCount encoding-decoding ****************************************************************/ -static U32 FSE_ctz(U32 val) -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* GCC Intrinsic */ - return __builtin_ctz(val); -# else /* Software version */ - U32 count = 0; - while ((val & 1) == 0) { - val >>= 1; - ++count; - } - return count; -# endif - } -} - FORCE_INLINE_TEMPLATE size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, const void* headerBuffer, size_t hbSize) @@ -102,7 +86,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne * repeat. * Avoid UB by setting the high bit to 1. */ - int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; + int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; while (repeats >= 12) { charnum += 3 * 12; if (LIKELY(ip <= iend-7)) { @@ -113,7 +97,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne ip = iend - 4; } bitStream = MEM_readLE32(ip) >> bitCount; - repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; + repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; } charnum += 3 * repeats; bitStream >>= 2 * repeats; @@ -178,7 +162,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne * know that threshold > 1. */ if (remaining <= 1) break; - nbBits = BIT_highbit32(remaining) + 1; + nbBits = ZSTD_highbit32(remaining) + 1; threshold = 1 << (nbBits - 1); } if (charnum >= maxSV1) break; @@ -253,7 +237,7 @@ size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, const void* src, size_t srcSize) { U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; - return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* bmi2 */ 0); + return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); } FORCE_INLINE_TEMPLATE size_t @@ -301,14 +285,14 @@ HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, if (weightTotal == 0) return ERROR(corruption_detected); /* get last non-null symbol weight (implied, total must be 2^n) */ - { U32 const tableLog = BIT_highbit32(weightTotal) + 1; + { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); *tableLogPtr = tableLog; /* determine last weight */ { U32 const total = 1 << tableLog; U32 const rest = total - weightTotal; - U32 const verif = 1 << BIT_highbit32(rest); - U32 const lastWeight = BIT_highbit32(rest) + 1; + U32 const verif = 1 << ZSTD_highbit32(rest); + U32 const lastWeight = ZSTD_highbit32(rest) + 1; if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ huffWeight[oSize] = (BYTE)lastWeight; rankStats[lastWeight]++; @@ -345,13 +329,13 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, - int bmi2) + int flags) { #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); } #endif - (void)bmi2; + (void)flags; return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); } diff --git a/lib/zstd/common/error_private.c b/lib/zstd/common/error_private.c index 6d1135f8c373..6c3dbad838b6 100644 --- a/lib/zstd/common/error_private.c +++ b/lib/zstd/common/error_private.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -27,9 +28,11 @@ const char* ERR_getErrorString(ERR_enum code) case PREFIX(version_unsupported): return "Version not supported"; case PREFIX(frameParameter_unsupported): return "Unsupported frame parameter"; case PREFIX(frameParameter_windowTooLarge): return "Frame requires too much memory for decoding"; - case PREFIX(corruption_detected): return "Corrupted block detected"; + case PREFIX(corruption_detected): return "Data corruption detected"; case PREFIX(checksum_wrong): return "Restored data doesn't match checksum"; + case PREFIX(literals_headerWrong): return "Header of Literals' block doesn't respect format specification"; case PREFIX(parameter_unsupported): return "Unsupported parameter"; + case PREFIX(parameter_combination_unsupported): return "Unsupported combination of parameters"; case PREFIX(parameter_outOfBound): return "Parameter is out of bound"; case PREFIX(init_missing): return "Context should be init first"; case PREFIX(memory_allocation): return "Allocation error : not enough memory"; @@ -38,17 +41,23 @@ const char* ERR_getErrorString(ERR_enum code) case PREFIX(tableLog_tooLarge): return "tableLog requires too much memory : unsupported"; case PREFIX(maxSymbolValue_tooLarge): return "Unsupported max Symbol Value : too large"; case PREFIX(maxSymbolValue_tooSmall): return "Specified maxSymbolValue is too small"; + case PREFIX(cannotProduce_uncompressedBlock): return "This mode cannot generate an uncompressed block"; + case PREFIX(stabilityCondition_notRespected): return "pledged buffer stability condition is not respected"; case PREFIX(dictionary_corrupted): return "Dictionary is corrupted"; case PREFIX(dictionary_wrong): return "Dictionary mismatch"; case PREFIX(dictionaryCreation_failed): return "Cannot create Dictionary from provided samples"; case PREFIX(dstSize_tooSmall): return "Destination buffer is too small"; case PREFIX(srcSize_wrong): return "Src size is incorrect"; case PREFIX(dstBuffer_null): return "Operation on NULL destination buffer"; + case PREFIX(noForwardProgress_destFull): return "Operation made no progress over multiple calls, due to output buffer being full"; + case PREFIX(noForwardProgress_inputEmpty): return "Operation made no progress over multiple calls, due to input being empty"; /* following error codes are not stable and may be removed or changed in a future version */ case PREFIX(frameIndex_tooLarge): return "Frame index is too large"; case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking"; case PREFIX(dstBuffer_wrong): return "Destination buffer is wrong"; case PREFIX(srcBuffer_wrong): return "Source buffer is wrong"; + case PREFIX(sequenceProducer_failed): return "Block-level external sequence producer returned an error code"; + case PREFIX(externalSequences_invalid): return "External sequences are not valid"; case PREFIX(maxCode): default: return notErrorCode; } diff --git a/lib/zstd/common/error_private.h b/lib/zstd/common/error_private.h index ca5101e542fa..08ee87b68cca 100644 --- a/lib/zstd/common/error_private.h +++ b/lib/zstd/common/error_private.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -13,8 +14,6 @@ #ifndef ERROR_H_MODULE #define ERROR_H_MODULE - - /* **************************************** * Dependencies ******************************************/ @@ -23,7 +22,6 @@ #include "debug.h" #include "zstd_deps.h" /* size_t */ - /* **************************************** * Compiler-specific ******************************************/ @@ -49,8 +47,13 @@ ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); } ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) { if (!ERR_isError(code)) return (ERR_enum)0; return (ERR_enum) (0-code); } /* check and forward error code */ -#define CHECK_V_F(e, f) size_t const e = f; if (ERR_isError(e)) return e -#define CHECK_F(f) { CHECK_V_F(_var_err__, f); } +#define CHECK_V_F(e, f) \ + size_t const e = f; \ + do { \ + if (ERR_isError(e)) \ + return e; \ + } while (0) +#define CHECK_F(f) do { CHECK_V_F(_var_err__, f); } while (0) /*-**************************************** @@ -84,10 +87,12 @@ void _force_has_format_string(const char *format, ...) { * We want to force this function invocation to be syntactically correct, but * we don't want to force runtime evaluation of its arguments. */ -#define _FORCE_HAS_FORMAT_STRING(...) \ - if (0) { \ - _force_has_format_string(__VA_ARGS__); \ - } +#define _FORCE_HAS_FORMAT_STRING(...) \ + do { \ + if (0) { \ + _force_has_format_string(__VA_ARGS__); \ + } \ + } while (0) #define ERR_QUOTE(str) #str @@ -98,48 +103,49 @@ void _force_has_format_string(const char *format, ...) { * In order to do that (particularly, printing the conditional that failed), * this can't just wrap RETURN_ERROR(). */ -#define RETURN_ERROR_IF(cond, err, ...) \ - if (cond) { \ - RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ - __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return ERROR(err); \ - } +#define RETURN_ERROR_IF(cond, err, ...) \ + do { \ + if (cond) { \ + RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ + __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return ERROR(err); \ + } \ + } while (0) /* * Unconditionally return the specified error. * * In debug modes, prints additional information. */ -#define RETURN_ERROR(err, ...) \ - do { \ - RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ - __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return ERROR(err); \ - } while(0); +#define RETURN_ERROR(err, ...) \ + do { \ + RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ + __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return ERROR(err); \ + } while(0) /* * If the provided expression evaluates to an error code, returns that error code. * * In debug modes, prints additional information. */ -#define FORWARD_IF_ERROR(err, ...) \ - do { \ - size_t const err_code = (err); \ - if (ERR_isError(err_code)) { \ - RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ - __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return err_code; \ - } \ - } while(0); - +#define FORWARD_IF_ERROR(err, ...) \ + do { \ + size_t const err_code = (err); \ + if (ERR_isError(err_code)) { \ + RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ + __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return err_code; \ + } \ + } while(0) #endif /* ERROR_H_MODULE */ diff --git a/lib/zstd/common/fse.h b/lib/zstd/common/fse.h index 4507043b2287..b36ce7a2a8c3 100644 --- a/lib/zstd/common/fse.h +++ b/lib/zstd/common/fse.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * FSE : Finite State Entropy codec * Public Prototypes declaration - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -11,8 +12,6 @@ * in the COPYING file in the root directory of this source tree). * You may select, at your option, one of the above-listed licenses. ****************************************************************** */ - - #ifndef FSE_H #define FSE_H @@ -22,7 +21,6 @@ ******************************************/ #include "zstd_deps.h" /* size_t, ptrdiff_t */ - /*-***************************************** * FSE_PUBLIC_API : control library symbols visibility ******************************************/ @@ -50,34 +48,6 @@ FSE_PUBLIC_API unsigned FSE_versionNumber(void); /*< library version number; to be used when checking dll version */ -/*-**************************************** -* FSE simple functions -******************************************/ -/*! FSE_compress() : - Compress content of buffer 'src', of size 'srcSize', into destination buffer 'dst'. - 'dst' buffer must be already allocated. Compression runs faster is dstCapacity >= FSE_compressBound(srcSize). - @return : size of compressed data (<= dstCapacity). - Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! - if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression instead. - if FSE_isError(return), compression failed (more details using FSE_getErrorName()) -*/ -FSE_PUBLIC_API size_t FSE_compress(void* dst, size_t dstCapacity, - const void* src, size_t srcSize); - -/*! FSE_decompress(): - Decompress FSE data from buffer 'cSrc', of size 'cSrcSize', - into already allocated destination buffer 'dst', of size 'dstCapacity'. - @return : size of regenerated data (<= maxDstSize), - or an error code, which can be tested using FSE_isError() . - - ** Important ** : FSE_decompress() does not decompress non-compressible nor RLE data !!! - Why ? : making this distinction requires a header. - Header management is intentionally delegated to the user layer, which can better manage special cases. -*/ -FSE_PUBLIC_API size_t FSE_decompress(void* dst, size_t dstCapacity, - const void* cSrc, size_t cSrcSize); - - /*-***************************************** * Tool functions ******************************************/ @@ -88,20 +58,6 @@ FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return FSE_PUBLIC_API const char* FSE_getErrorName(size_t code); /* provides error code string (useful for debugging) */ -/*-***************************************** -* FSE advanced functions -******************************************/ -/*! FSE_compress2() : - Same as FSE_compress(), but allows the selection of 'maxSymbolValue' and 'tableLog' - Both parameters can be defined as '0' to mean : use default value - @return : size of compressed data - Special values : if return == 0, srcData is not compressible => Nothing is stored within cSrc !!! - if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression. - if FSE_isError(return), it's an error code. -*/ -FSE_PUBLIC_API size_t FSE_compress2 (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); - - /*-***************************************** * FSE detailed API ******************************************/ @@ -161,8 +117,6 @@ FSE_PUBLIC_API size_t FSE_writeNCount (void* buffer, size_t bufferSize, /*! Constructor and Destructor of FSE_CTable. Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */ typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */ -FSE_PUBLIC_API FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog); -FSE_PUBLIC_API void FSE_freeCTable (FSE_CTable* ct); /*! FSE_buildCTable(): Builds `ct`, which must be already allocated, using FSE_createCTable(). @@ -238,23 +192,7 @@ FSE_PUBLIC_API size_t FSE_readNCount_bmi2(short* normalizedCounter, unsigned* maxSymbolValuePtr, unsigned* tableLogPtr, const void* rBuffer, size_t rBuffSize, int bmi2); -/*! Constructor and Destructor of FSE_DTable. - Note that its size depends on 'tableLog' */ typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */ -FSE_PUBLIC_API FSE_DTable* FSE_createDTable(unsigned tableLog); -FSE_PUBLIC_API void FSE_freeDTable(FSE_DTable* dt); - -/*! FSE_buildDTable(): - Builds 'dt', which must be already allocated, using FSE_createDTable(). - return : 0, or an errorCode, which can be tested using FSE_isError() */ -FSE_PUBLIC_API size_t FSE_buildDTable (FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog); - -/*! FSE_decompress_usingDTable(): - Decompress compressed source `cSrc` of size `cSrcSize` using `dt` - into `dst` which must be already allocated. - @return : size of regenerated data (necessarily <= `dstCapacity`), - or an errorCode, which can be tested using FSE_isError() */ -FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, const FSE_DTable* dt); /*! Tutorial : @@ -286,13 +224,11 @@ If there is an error, the function will return an error code, which can be teste #endif /* FSE_H */ + #if !defined(FSE_H_FSE_STATIC_LINKING_ONLY) #define FSE_H_FSE_STATIC_LINKING_ONLY - -/* *** Dependency *** */ #include "bitstream.h" - /* ***************************************** * Static allocation *******************************************/ @@ -317,16 +253,6 @@ If there is an error, the function will return an error code, which can be teste unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus); /*< same as FSE_optimalTableLog(), which used `minus==2` */ -/* FSE_compress_wksp() : - * Same as FSE_compress2(), but using an externally allocated scratch buffer (`workSpace`). - * FSE_COMPRESS_WKSP_SIZE_U32() provides the minimum size required for `workSpace` as a table of FSE_CTable. - */ -#define FSE_COMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) ( FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) + ((maxTableLog > 12) ? (1 << (maxTableLog - 2)) : 1024) ) -size_t FSE_compress_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); - -size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits); -/*< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */ - size_t FSE_buildCTable_rle (FSE_CTable* ct, unsigned char symbolValue); /*< build a fake FSE_CTable, designed to compress always the same symbolValue */ @@ -344,19 +270,11 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsi FSE_PUBLIC_API size_t FSE_buildDTable_wksp(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); /*< Same as FSE_buildDTable(), using an externally allocated `workspace` produced with `FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxSymbolValue)` */ -size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits); -/*< build a fake FSE_DTable, designed to read a flat distribution where each symbol uses nbBits */ - -size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue); -/*< build a fake FSE_DTable, designed to always generate the same symbolValue */ - -#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) +#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + 1 + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) #define FSE_DECOMPRESS_WKSP_SIZE(maxTableLog, maxSymbolValue) (FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(unsigned)) -size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize); -/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)` */ - size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize, int bmi2); -/*< Same as FSE_decompress_wksp() but with dynamic BMI2 support. Pass 1 if your CPU supports BMI2 or 0 if it doesn't. */ +/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)`. + * Set bmi2 to 1 if your CPU supports BMI2 or 0 if it doesn't */ typedef enum { FSE_repeat_none, /*< Cannot use the previous table */ @@ -539,20 +457,20 @@ MEM_STATIC void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* statePtr, un FSE_symbolCompressionTransform const symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol]; const U16* const stateTable = (const U16*)(statePtr->stateTable); U32 const nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16); - BIT_addBits(bitC, statePtr->value, nbBitsOut); + BIT_addBits(bitC, (BitContainerType)statePtr->value, nbBitsOut); statePtr->value = stateTable[ (statePtr->value >> nbBitsOut) + symbolTT.deltaFindState]; } MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePtr) { - BIT_addBits(bitC, statePtr->value, statePtr->stateLog); + BIT_addBits(bitC, (BitContainerType)statePtr->value, statePtr->stateLog); BIT_flushBits(bitC); } /* FSE_getMaxNbBits() : * Approximate maximum cost of a symbol, in bits. - * Fractional get rounded up (i.e : a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) + * Fractional get rounded up (i.e. a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) * note 1 : assume symbolValue is valid (<= maxSymbolValue) * note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits */ MEM_STATIC U32 FSE_getMaxNbBits(const void* symbolTTPtr, U32 symbolValue) @@ -705,7 +623,4 @@ MEM_STATIC unsigned FSE_endOfDState(const FSE_DState_t* DStatePtr) #define FSE_TABLESTEP(tableSize) (((tableSize)>>1) + ((tableSize)>>3) + 3) - #endif /* FSE_STATIC_LINKING_ONLY */ - - diff --git a/lib/zstd/common/fse_decompress.c b/lib/zstd/common/fse_decompress.c index 8dcb8ca39767..15081d8dc607 100644 --- a/lib/zstd/common/fse_decompress.c +++ b/lib/zstd/common/fse_decompress.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * FSE : Finite State Entropy decoder - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -22,8 +23,8 @@ #define FSE_STATIC_LINKING_ONLY #include "fse.h" #include "error_private.h" -#define ZSTD_DEPS_NEED_MALLOC -#include "zstd_deps.h" +#include "zstd_deps.h" /* ZSTD_memcpy */ +#include "bits.h" /* ZSTD_highbit32 */ /* ************************************************************** @@ -55,19 +56,6 @@ #define FSE_FUNCTION_NAME(X,Y) FSE_CAT(X,Y) #define FSE_TYPE_NAME(X,Y) FSE_CAT(X,Y) - -/* Function templates */ -FSE_DTable* FSE_createDTable (unsigned tableLog) -{ - if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; - return (FSE_DTable*)ZSTD_malloc( FSE_DTABLE_SIZE_U32(tableLog) * sizeof (U32) ); -} - -void FSE_freeDTable (FSE_DTable* dt) -{ - ZSTD_free(dt); -} - static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize) { void* const tdPtr = dt+1; /* because *dt is unsigned, 32-bits aligned on 32-bits */ @@ -96,7 +84,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo symbolNext[s] = 1; } else { if (normalizedCounter[s] >= largeLimit) DTableH.fastMode=0; - symbolNext[s] = normalizedCounter[s]; + symbolNext[s] = (U16)normalizedCounter[s]; } } } ZSTD_memcpy(dt, &DTableH, sizeof(DTableH)); } @@ -111,8 +99,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo * all symbols have counts <= 8. We ensure we have 8 bytes at the end of * our buffer to handle the over-write. */ - { - U64 const add = 0x0101010101010101ull; + { U64 const add = 0x0101010101010101ull; size_t pos = 0; U64 sv = 0; U32 s; @@ -123,14 +110,13 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo for (i = 8; i < n; i += 8) { MEM_write64(spread + pos + i, sv); } - pos += n; - } - } + pos += (size_t)n; + } } /* Now we spread those positions across the table. - * The benefit of doing it in two stages is that we avoid the the + * The benefit of doing it in two stages is that we avoid the * variable size inner loop, which caused lots of branch misses. * Now we can run through all the positions without any branch misses. - * We unroll the loop twice, since that is what emperically worked best. + * We unroll the loop twice, since that is what empirically worked best. */ { size_t position = 0; @@ -166,7 +152,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo for (u=0; utableLog = 0; - DTableH->fastMode = 0; - - cell->newState = 0; - cell->symbol = symbolValue; - cell->nbBits = 0; - - return 0; -} - - -size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits) -{ - void* ptr = dt; - FSE_DTableHeader* const DTableH = (FSE_DTableHeader*)ptr; - void* dPtr = dt + 1; - FSE_decode_t* const dinfo = (FSE_decode_t*)dPtr; - const unsigned tableSize = 1 << nbBits; - const unsigned tableMask = tableSize - 1; - const unsigned maxSV1 = tableMask+1; - unsigned s; - - /* Sanity checks */ - if (nbBits < 1) return ERROR(GENERIC); /* min size */ - - /* Build Decoding Table */ - DTableH->tableLog = (U16)nbBits; - DTableH->fastMode = 1; - for (s=0; sfastMode; - - /* select fast mode (static) */ - if (fastMode) return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 1); - return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 0); -} - - -size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize) -{ - return FSE_decompress_wksp_bmi2(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize, /* bmi2 */ 0); + assert(op >= ostart); + return (size_t)(op-ostart); } typedef struct { short ncount[FSE_MAX_SYMBOL_VALUE + 1]; - FSE_DTable dtable[]; /* Dynamically sized */ } FSE_DecompressWksp; @@ -327,13 +252,18 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( unsigned tableLog; unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE; FSE_DecompressWksp* const wksp = (FSE_DecompressWksp*)workSpace; + size_t const dtablePos = sizeof(FSE_DecompressWksp) / sizeof(FSE_DTable); + FSE_DTable* const dtable = (FSE_DTable*)workSpace + dtablePos; - DEBUG_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); + FSE_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); if (wkspSize < sizeof(*wksp)) return ERROR(GENERIC); + /* correct offset to dtable depends on this property */ + FSE_STATIC_ASSERT(sizeof(FSE_DecompressWksp) % sizeof(FSE_DTable) == 0); + /* normal FSE decoding mode */ - { - size_t const NCountLength = FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); + { size_t const NCountLength = + FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); if (FSE_isError(NCountLength)) return NCountLength; if (tableLog > maxLog) return ERROR(tableLog_tooLarge); assert(NCountLength <= cSrcSize); @@ -342,19 +272,20 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( } if (FSE_DECOMPRESS_WKSP_SIZE(tableLog, maxSymbolValue) > wkspSize) return ERROR(tableLog_tooLarge); - workSpace = wksp->dtable + FSE_DTABLE_SIZE_U32(tableLog); + assert(sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog) <= wkspSize); + workSpace = (BYTE*)workSpace + sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); wkspSize -= sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); - CHECK_F( FSE_buildDTable_internal(wksp->dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); + CHECK_F( FSE_buildDTable_internal(dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); { - const void* ptr = wksp->dtable; + const void* ptr = dtable; const FSE_DTableHeader* DTableH = (const FSE_DTableHeader*)ptr; const U32 fastMode = DTableH->fastMode; /* select fast mode (static) */ - if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 1); - return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 0); + if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 1); + return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 0); } } @@ -382,9 +313,4 @@ size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, return FSE_decompress_wksp_body_default(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize); } - -typedef FSE_DTable DTable_max_t[FSE_DTABLE_SIZE_U32(FSE_MAX_TABLELOG)]; - - - #endif /* FSE_COMMONDEFS_ONLY */ diff --git a/lib/zstd/common/huf.h b/lib/zstd/common/huf.h index 5042ff870308..49736dcd8f49 100644 --- a/lib/zstd/common/huf.h +++ b/lib/zstd/common/huf.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * huff0 huffman codec, * part of Finite State Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -12,105 +13,26 @@ * You may select, at your option, one of the above-listed licenses. ****************************************************************** */ - #ifndef HUF_H_298734234 #define HUF_H_298734234 /* *** Dependencies *** */ #include "zstd_deps.h" /* size_t */ - - -/* *** library symbols visibility *** */ -/* Note : when linking with -fvisibility=hidden on gcc, or by default on Visual, - * HUF symbols remain "private" (internal symbols for library only). - * Set macro FSE_DLL_EXPORT to 1 if you want HUF symbols visible on DLL interface */ -#if defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) && defined(__GNUC__) && (__GNUC__ >= 4) -# define HUF_PUBLIC_API __attribute__ ((visibility ("default"))) -#elif defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) /* Visual expected */ -# define HUF_PUBLIC_API __declspec(dllexport) -#elif defined(FSE_DLL_IMPORT) && (FSE_DLL_IMPORT==1) -# define HUF_PUBLIC_API __declspec(dllimport) /* not required, just to generate faster code (saves a function pointer load from IAT and an indirect jump) */ -#else -# define HUF_PUBLIC_API -#endif - - -/* ========================== */ -/* *** simple functions *** */ -/* ========================== */ - -/* HUF_compress() : - * Compress content from buffer 'src', of size 'srcSize', into buffer 'dst'. - * 'dst' buffer must be already allocated. - * Compression runs faster if `dstCapacity` >= HUF_compressBound(srcSize). - * `srcSize` must be <= `HUF_BLOCKSIZE_MAX` == 128 KB. - * @return : size of compressed data (<= `dstCapacity`). - * Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! - * if HUF_isError(return), compression failed (more details using HUF_getErrorName()) - */ -HUF_PUBLIC_API size_t HUF_compress(void* dst, size_t dstCapacity, - const void* src, size_t srcSize); - -/* HUF_decompress() : - * Decompress HUF data from buffer 'cSrc', of size 'cSrcSize', - * into already allocated buffer 'dst', of minimum size 'dstSize'. - * `originalSize` : **must** be the ***exact*** size of original (uncompressed) data. - * Note : in contrast with FSE, HUF_decompress can regenerate - * RLE (cSrcSize==1) and uncompressed (cSrcSize==dstSize) data, - * because it knows size to regenerate (originalSize). - * @return : size of regenerated data (== originalSize), - * or an error code, which can be tested using HUF_isError() - */ -HUF_PUBLIC_API size_t HUF_decompress(void* dst, size_t originalSize, - const void* cSrc, size_t cSrcSize); - +#include "mem.h" /* U32 */ +#define FSE_STATIC_LINKING_ONLY +#include "fse.h" /* *** Tool functions *** */ -#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ -HUF_PUBLIC_API size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ +#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ +size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ /* Error Management */ -HUF_PUBLIC_API unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ -HUF_PUBLIC_API const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ +unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ +const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ -/* *** Advanced function *** */ - -/* HUF_compress2() : - * Same as HUF_compress(), but offers control over `maxSymbolValue` and `tableLog`. - * `maxSymbolValue` must be <= HUF_SYMBOLVALUE_MAX . - * `tableLog` must be `<= HUF_TABLELOG_MAX` . */ -HUF_PUBLIC_API size_t HUF_compress2 (void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned tableLog); - -/* HUF_compress4X_wksp() : - * Same as HUF_compress2(), but uses externally allocated `workSpace`. - * `workspace` must be at least as large as HUF_WORKSPACE_SIZE */ #define HUF_WORKSPACE_SIZE ((8 << 10) + 512 /* sorting scratch space */) #define HUF_WORKSPACE_SIZE_U64 (HUF_WORKSPACE_SIZE / sizeof(U64)) -HUF_PUBLIC_API size_t HUF_compress4X_wksp (void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned tableLog, - void* workSpace, size_t wkspSize); - -#endif /* HUF_H_298734234 */ - -/* ****************************************************************** - * WARNING !! - * The following section contains advanced and experimental definitions - * which shall never be used in the context of a dynamic library, - * because they are not guaranteed to remain stable in the future. - * Only consider them in association with static linking. - * *****************************************************************/ -#if !defined(HUF_H_HUF_STATIC_LINKING_ONLY) -#define HUF_H_HUF_STATIC_LINKING_ONLY - -/* *** Dependencies *** */ -#include "mem.h" /* U32 */ -#define FSE_STATIC_LINKING_ONLY -#include "fse.h" - /* *** Constants *** */ #define HUF_TABLELOG_MAX 12 /* max runtime value of tableLog (due to static allocation); can be modified up to HUF_TABLELOG_ABSOLUTEMAX */ @@ -151,25 +73,49 @@ typedef U32 HUF_DTable; /* **************************************** * Advanced decompression functions ******************************************/ -size_t HUF_decompress4X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -#endif -size_t HUF_decompress4X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< decodes RLE and uncompressed */ -size_t HUF_decompress4X_hufOnly(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< considers RLE and uncompressed as errors */ -size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< considers RLE and uncompressed as errors */ -size_t HUF_decompress4X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ -#endif +/* + * Huffman flags bitset. + * For all flags, 0 is the default value. + */ +typedef enum { + /* + * If compiled with DYNAMIC_BMI2: Set flag only if the CPU supports BMI2 at runtime. + * Otherwise: Ignored. + */ + HUF_flags_bmi2 = (1 << 0), + /* + * If set: Test possible table depths to find the one that produces the smallest header + encoded size. + * If unset: Use heuristic to find the table depth. + */ + HUF_flags_optimalDepth = (1 << 1), + /* + * If set: If the previous table can encode the input, always reuse the previous table. + * If unset: If the previous table can encode the input, reuse the previous table if it results in a smaller output. + */ + HUF_flags_preferRepeat = (1 << 2), + /* + * If set: Sample the input and check if the sample is uncompressible, if it is then don't attempt to compress. + * If unset: Always histogram the entire input. + */ + HUF_flags_suspectUncompressible = (1 << 3), + /* + * If set: Don't use assembly implementations + * If unset: Allow using assembly implementations + */ + HUF_flags_disableAsm = (1 << 4), + /* + * If set: Don't use the fast decoding loop, always use the fallback decoding loop. + * If unset: Use the fast decoding loop when possible. + */ + HUF_flags_disableFast = (1 << 5) +} HUF_flags_e; /* **************************************** * HUF detailed API * ****************************************/ +#define HUF_OPTIMAL_DEPTH_THRESHOLD ZSTD_btultra /*! HUF_compress() does the following: * 1. count symbol occurrence from source[] into table count[] using FSE_count() (exposed within "fse.h") @@ -182,12 +128,12 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, * For example, it's possible to compress several blocks using the same 'CTable', * or to save and regenerate 'CTable' using external methods. */ -unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue); -size_t HUF_buildCTable (HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue, unsigned maxNbBits); /* @return : maxNbBits; CTable and count can overlap. In which case, CTable will overwrite count content */ -size_t HUF_writeCTable (void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog); +unsigned HUF_minTableLog(unsigned symbolCardinality); +unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue); +unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, void* workSpace, + size_t wkspSize, HUF_CElt* table, const unsigned* count, int flags); /* table is used as scratch space for building and testing tables, not a return value */ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog, void* workspace, size_t workspaceSize); -size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); -size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); +size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); @@ -196,6 +142,7 @@ typedef enum { HUF_repeat_check, /*< Can use the previous table but it must be checked. Note : The previous table must have been constructed by HUF_compress{1, 4}X_repeat */ HUF_repeat_valid /*< Can use the previous table and it is assumed to be valid */ } HUF_repeat; + /* HUF_compress4X_repeat() : * Same as HUF_compress4X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. @@ -206,13 +153,13 @@ size_t HUF_compress4X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); + HUF_CElt* hufTable, HUF_repeat* repeat, int flags); /* HUF_buildCTable_wksp() : * Same as HUF_buildCTable(), but using externally allocated scratch buffer. * `workSpace` must be aligned on 4-bytes boundaries, and its size must be >= HUF_CTABLE_WORKSPACE_SIZE. */ -#define HUF_CTABLE_WORKSPACE_SIZE_U32 (2*HUF_SYMBOLVALUE_MAX +1 +1) +#define HUF_CTABLE_WORKSPACE_SIZE_U32 ((4 * (HUF_SYMBOLVALUE_MAX + 1)) + 192) #define HUF_CTABLE_WORKSPACE_SIZE (HUF_CTABLE_WORKSPACE_SIZE_U32 * sizeof(unsigned)) size_t HUF_buildCTable_wksp (HUF_CElt* tree, const unsigned* count, U32 maxSymbolValue, U32 maxNbBits, @@ -238,7 +185,7 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize, void* workspace, size_t wkspSize, - int bmi2); + int flags); /* HUF_readCTable() : * Loading a CTable saved with HUF_writeCTable() */ @@ -246,9 +193,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void /* HUF_getNbBitsFromCTable() : * Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX - * Note 1 : is not inlined, as HUF_CElt definition is private */ + * Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0 + * Note 2 : is not inlined, as HUF_CElt definition is private + */ U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue); +typedef struct { + BYTE tableLog; + BYTE maxSymbolValue; + BYTE unused[sizeof(size_t) - 2]; +} HUF_CTableHeader; + +/* HUF_readCTableHeader() : + * @returns The header from the CTable specifying the tableLog and the maxSymbolValue. + */ +HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable); + /* * HUF_decompress() does the following: * 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics @@ -276,32 +236,12 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize); #define HUF_DECOMPRESS_WORKSPACE_SIZE ((2 << 10) + (1 << 9)) #define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_readDTableX1 (HUF_DTable* DTable, const void* src, size_t srcSize); -size_t HUF_readDTableX1_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_readDTableX2 (HUF_DTable* DTable, const void* src, size_t srcSize); -size_t HUF_readDTableX2_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); -#endif - -size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress4X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif - /* ====================== */ /* single stream variants */ /* ====================== */ -size_t HUF_compress1X (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); -size_t HUF_compress1X_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); /*< `workSpace` must be a table of at least HUF_WORKSPACE_SIZE_U64 U64 */ -size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); -size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); +size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); /* HUF_compress1X_repeat() : * Same as HUF_compress1X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. @@ -312,47 +252,27 @@ size_t HUF_compress1X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); - -size_t HUF_decompress1X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* double-symbol decoder */ -#endif - -size_t HUF_decompress1X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); -size_t HUF_decompress1X_DCtx_wksp (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ -#endif + HUF_CElt* hufTable, HUF_repeat* repeat, int flags); -size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); /*< automatic selection of sing or double symbol decoder, based on DTable */ -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif +size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); /*< double-symbols decoder */ #endif /* BMI2 variants. * If the CPU has BMI2 support, pass bmi2=1, otherwise pass bmi2=0. */ -size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); +size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #endif -size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); -size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif #ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif -#endif /* HUF_STATIC_LINKING_ONLY */ - +#endif /* HUF_H_298734234 */ diff --git a/lib/zstd/common/mem.h b/lib/zstd/common/mem.h index c22a2e69bf46..d9bd752fe17b 100644 --- a/lib/zstd/common/mem.h +++ b/lib/zstd/common/mem.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -24,6 +24,7 @@ /*-**************************************** * Compiler specifics ******************************************/ +#undef MEM_STATIC /* may be already defined from common/compiler.h */ #define MEM_STATIC static inline /*-************************************************************** diff --git a/lib/zstd/common/portability_macros.h b/lib/zstd/common/portability_macros.h index 0dde8bf56595..efae9465d57d 100644 --- a/lib/zstd/common/portability_macros.h +++ b/lib/zstd/common/portability_macros.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -12,7 +13,7 @@ #define ZSTD_PORTABILITY_MACROS_H /* - * This header file contains macro defintions to support portability. + * This header file contains macro definitions to support portability. * This header is shared between C and ASM code, so it MUST only * contain macro definitions. It MUST not contain any C code. * @@ -45,30 +46,37 @@ /* Mark the internal assembly functions as hidden */ #ifdef __ELF__ # define ZSTD_HIDE_ASM_FUNCTION(func) .hidden func +#elif defined(__APPLE__) +# define ZSTD_HIDE_ASM_FUNCTION(func) .private_extern func #else # define ZSTD_HIDE_ASM_FUNCTION(func) #endif +/* Compile time determination of BMI2 support */ + + /* Enable runtime BMI2 dispatch based on the CPU. - * Enabled for clang & gcc >=4.8 on x86 when BMI2 isn't enabled by default. + * Enabled for clang & gcc >= 11.4 on x86 when BMI2 isn't enabled by default. + * Disabled for gcc < 11.4 because of a segfault while compiling + * HUF_compress1X_usingCTable_internal_body(). */ #ifndef DYNAMIC_BMI2 - #if ((defined(__clang__) && __has_attribute(__target__)) \ +# if ((defined(__clang__) && __has_attribute(__target__)) \ || (defined(__GNUC__) \ - && (__GNUC__ >= 11))) \ - && (defined(__x86_64__) || defined(_M_X64)) \ + && (__GNUC__ >= 12 || (__GNUC__ == 11 && __GNUC_MINOR__ >= 4)))) \ + && (defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)) \ && !defined(__BMI2__) - # define DYNAMIC_BMI2 1 - #else - # define DYNAMIC_BMI2 0 - #endif +# define DYNAMIC_BMI2 1 +# else +# define DYNAMIC_BMI2 0 +# endif #endif /* - * Only enable assembly for GNUC comptabile compilers, + * Only enable assembly for GNU C compatible compilers, * because other platforms may not support GAS assembly syntax. * - * Only enable assembly for Linux / MacOS, other platforms may + * Only enable assembly for Linux / MacOS / Win32, other platforms may * work, but they haven't been tested. This could likely be * extended to BSD systems. * @@ -90,4 +98,23 @@ */ #define ZSTD_ENABLE_ASM_X86_64_BMI2 0 +/* + * For x86 ELF targets, add .note.gnu.property section for Intel CET in + * assembly sources when CET is enabled. + * + * Additionally, any function that may be called indirectly must begin + * with ZSTD_CET_ENDBRANCH. + */ +#if defined(__ELF__) && (defined(__x86_64__) || defined(__i386__)) \ + && defined(__has_include) +# if __has_include() +# include +# define ZSTD_CET_ENDBRANCH _CET_ENDBR +# endif +#endif + +#ifndef ZSTD_CET_ENDBRANCH +# define ZSTD_CET_ENDBRANCH +#endif + #endif /* ZSTD_PORTABILITY_MACROS_H */ diff --git a/lib/zstd/common/zstd_common.c b/lib/zstd/common/zstd_common.c index 3d7e35b309b5..44b95b25344a 100644 --- a/lib/zstd/common/zstd_common.c +++ b/lib/zstd/common/zstd_common.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,7 +15,6 @@ * Dependencies ***************************************/ #define ZSTD_DEPS_NEED_MALLOC -#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ #include "error_private.h" #include "zstd_internal.h" @@ -47,37 +47,3 @@ ZSTD_ErrorCode ZSTD_getErrorCode(size_t code) { return ERR_getErrorCode(code); } /*! ZSTD_getErrorString() : * provides error code string from enum */ const char* ZSTD_getErrorString(ZSTD_ErrorCode code) { return ERR_getErrorString(code); } - - - -/*=************************************************************** -* Custom allocator -****************************************************************/ -void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) -{ - if (customMem.customAlloc) - return customMem.customAlloc(customMem.opaque, size); - return ZSTD_malloc(size); -} - -void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) -{ - if (customMem.customAlloc) { - /* calloc implemented as malloc+memset; - * not as efficient as calloc, but next best guess for custom malloc */ - void* const ptr = customMem.customAlloc(customMem.opaque, size); - ZSTD_memset(ptr, 0, size); - return ptr; - } - return ZSTD_calloc(1, size); -} - -void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) -{ - if (ptr!=NULL) { - if (customMem.customFree) - customMem.customFree(customMem.opaque, ptr); - else - ZSTD_free(ptr); - } -} diff --git a/lib/zstd/common/zstd_deps.h b/lib/zstd/common/zstd_deps.h index 2c34e8a33a1c..f931f7d0e294 100644 --- a/lib/zstd/common/zstd_deps.h +++ b/lib/zstd/common/zstd_deps.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -105,3 +105,17 @@ static uint64_t ZSTD_div64(uint64_t dividend, uint32_t divisor) { #endif /* ZSTD_DEPS_IO */ #endif /* ZSTD_DEPS_NEED_IO */ + +/* + * Only requested when MSAN is enabled. + * Need: + * intptr_t + */ +#ifdef ZSTD_DEPS_NEED_STDINT +#ifndef ZSTD_DEPS_STDINT +#define ZSTD_DEPS_STDINT + +/* intptr_t already provided by ZSTD_DEPS_COMMON */ + +#endif /* ZSTD_DEPS_STDINT */ +#endif /* ZSTD_DEPS_NEED_STDINT */ diff --git a/lib/zstd/common/zstd_internal.h b/lib/zstd/common/zstd_internal.h index 93305d9b41bb..52a79435caf6 100644 --- a/lib/zstd/common/zstd_internal.h +++ b/lib/zstd/common/zstd_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -28,12 +29,10 @@ #include #define FSE_STATIC_LINKING_ONLY #include "fse.h" -#define HUF_STATIC_LINKING_ONLY #include "huf.h" #include /* XXH_reset, update, digest */ #define ZSTD_TRACE 0 - /* ---- static assert (debug) --- */ #define ZSTD_STATIC_ASSERT(c) DEBUG_STATIC_ASSERT(c) #define ZSTD_isError ERR_isError /* for inlining */ @@ -83,16 +82,17 @@ typedef enum { bt_raw, bt_rle, bt_compressed, bt_reserved } blockType_e; #define ZSTD_FRAMECHECKSUMSIZE 4 #define MIN_SEQUENCES_SIZE 1 /* nbSeq==0 */ -#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */ + MIN_SEQUENCES_SIZE /* nbSeq==0 */) /* for a non-null block */ +#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */) /* for a non-null block */ +#define MIN_LITERALS_FOR_4_STREAMS 6 -#define HufLog 12 -typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingType_e; +typedef enum { set_basic, set_rle, set_compressed, set_repeat } SymbolEncodingType_e; #define LONGNBSEQ 0x7F00 #define MINMATCH 3 #define Litbits 8 +#define LitHufLog 11 #define MaxLit ((1<= WILDCOPY_VECLEN || diff <= -WILDCOPY_VECLEN); @@ -225,12 +227,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e * one COPY16() in the first call. Then, do two calls per loop since * at that point it is more likely to have a high trip count. */ -#ifdef __aarch64__ - do { - COPY16(op, ip); - } - while (op < oend); -#else ZSTD_copy16(op, ip); if (16 >= length) return; op += 16; @@ -240,7 +236,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e COPY16(op, ip); } while (op < oend); -#endif } } @@ -273,62 +268,6 @@ typedef enum { /*-******************************************* * Private declarations *********************************************/ -typedef struct seqDef_s { - U32 offBase; /* offBase == Offset + ZSTD_REP_NUM, or repcode 1,2,3 */ - U16 litLength; - U16 mlBase; /* mlBase == matchLength - MINMATCH */ -} seqDef; - -/* Controls whether seqStore has a single "long" litLength or matchLength. See seqStore_t. */ -typedef enum { - ZSTD_llt_none = 0, /* no longLengthType */ - ZSTD_llt_literalLength = 1, /* represents a long literal */ - ZSTD_llt_matchLength = 2 /* represents a long match */ -} ZSTD_longLengthType_e; - -typedef struct { - seqDef* sequencesStart; - seqDef* sequences; /* ptr to end of sequences */ - BYTE* litStart; - BYTE* lit; /* ptr to end of literals */ - BYTE* llCode; - BYTE* mlCode; - BYTE* ofCode; - size_t maxNbSeq; - size_t maxNbLit; - - /* longLengthPos and longLengthType to allow us to represent either a single litLength or matchLength - * in the seqStore that has a value larger than U16 (if it exists). To do so, we increment - * the existing value of the litLength or matchLength by 0x10000. - */ - ZSTD_longLengthType_e longLengthType; - U32 longLengthPos; /* Index of the sequence to apply long length modification to */ -} seqStore_t; - -typedef struct { - U32 litLength; - U32 matchLength; -} ZSTD_sequenceLength; - -/* - * Returns the ZSTD_sequenceLength for the given sequences. It handles the decoding of long sequences - * indicated by longLengthPos and longLengthType, and adds MINMATCH back to matchLength. - */ -MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore, seqDef const* seq) -{ - ZSTD_sequenceLength seqLen; - seqLen.litLength = seq->litLength; - seqLen.matchLength = seq->mlBase + MINMATCH; - if (seqStore->longLengthPos == (U32)(seq - seqStore->sequencesStart)) { - if (seqStore->longLengthType == ZSTD_llt_literalLength) { - seqLen.litLength += 0xFFFF; - } - if (seqStore->longLengthType == ZSTD_llt_matchLength) { - seqLen.matchLength += 0xFFFF; - } - } - return seqLen; -} /* * Contains the compressed frame size and an upper-bound for the decompressed frame size. @@ -337,74 +276,11 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore * `decompressedBound != ZSTD_CONTENTSIZE_ERROR` */ typedef struct { + size_t nbBlocks; size_t compressedSize; unsigned long long decompressedBound; } ZSTD_frameSizeInfo; /* decompress & legacy */ -const seqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx); /* compress & dictBuilder */ -void ZSTD_seqToCodes(const seqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ - -/* custom memory allocation functions */ -void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem); -void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem); -void ZSTD_customFree(void* ptr, ZSTD_customMem customMem); - - -MEM_STATIC U32 ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* GCC Intrinsic */ - return __builtin_clz (val) ^ 31; -# else /* Software version */ - static const U32 DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31 }; - U32 v = val; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - return DeBruijnClz[(v * 0x07C4ACDDU) >> 27]; -# endif - } -} - -/* - * Counts the number of trailing zeros of a `size_t`. - * Most compilers should support CTZ as a builtin. A backup - * implementation is provided if the builtin isn't supported, but - * it may not be terribly efficient. - */ -MEM_STATIC unsigned ZSTD_countTrailingZeros(size_t val) -{ - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return __builtin_ctzll((U64)val); -# else - static const int DeBruijnBytePos[64] = { 0, 1, 2, 7, 3, 13, 8, 19, - 4, 25, 14, 28, 9, 34, 20, 56, - 5, 17, 26, 54, 15, 41, 29, 43, - 10, 31, 38, 35, 21, 45, 49, 57, - 63, 6, 12, 18, 24, 27, 33, 55, - 16, 53, 40, 42, 30, 37, 44, 48, - 62, 11, 23, 32, 52, 39, 36, 47, - 61, 22, 51, 46, 60, 50, 59, 58 }; - return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return __builtin_ctz((U32)val); -# else - static const int DeBruijnBytePos[32] = { 0, 1, 28, 2, 29, 14, 24, 3, - 30, 22, 20, 15, 25, 17, 4, 8, - 31, 27, 13, 23, 21, 19, 16, 7, - 26, 12, 18, 6, 11, 5, 10, 9 }; - return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; -# endif - } -} - - /* ZSTD_invalidateRepCodes() : * ensures next compression will not use repcodes from previous block. * Note : only works with regular variant; @@ -420,13 +296,13 @@ typedef struct { /*! ZSTD_getcBlockSize() : * Provides the size of compressed block from block header `src` */ -/* Used by: decompress, fullbench (does not get its definition from here) */ +/* Used by: decompress, fullbench */ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, blockProperties_t* bpPtr); /*! ZSTD_decodeSeqHeaders() : * decode sequence header from src */ -/* Used by: decompress, fullbench (does not get its definition from here) */ +/* Used by: zstd_decompress_block, fullbench */ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, const void* src, size_t srcSize); @@ -439,5 +315,4 @@ MEM_STATIC int ZSTD_cpuSupportsBmi2(void) return ZSTD_cpuid_bmi1(cpuid) && ZSTD_cpuid_bmi2(cpuid); } - #endif /* ZSTD_CCOMMON_H_MODULE */ diff --git a/lib/zstd/compress/clevels.h b/lib/zstd/compress/clevels.h index d9a76112ec3a..6ab8be6532ef 100644 --- a/lib/zstd/compress/clevels.h +++ b/lib/zstd/compress/clevels.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/fse_compress.c b/lib/zstd/compress/fse_compress.c index ec5b1ca6d71a..44a3c10becf2 100644 --- a/lib/zstd/compress/fse_compress.c +++ b/lib/zstd/compress/fse_compress.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * FSE : Finite State Entropy encoder - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -25,7 +26,8 @@ #include "../common/error_private.h" #define ZSTD_DEPS_NEED_MALLOC #define ZSTD_DEPS_NEED_MATH64 -#include "../common/zstd_deps.h" /* ZSTD_malloc, ZSTD_free, ZSTD_memcpy, ZSTD_memset */ +#include "../common/zstd_deps.h" /* ZSTD_memset */ +#include "../common/bits.h" /* ZSTD_highbit32 */ /* ************************************************************** @@ -90,7 +92,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, assert(tableLog < 16); /* required for threshold strategy to work */ /* For explanations on how to distribute symbol values over the table : - * http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ + * https://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ #ifdef __clang_analyzer__ ZSTD_memset(tableSymbol, 0, sizeof(*tableSymbol) * tableSize); /* useless initialization, just to keep scan-build happy */ @@ -191,7 +193,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, break; default : assert(normalizedCounter[s] > 1); - { U32 const maxBitsOut = tableLog - BIT_highbit32 ((U32)normalizedCounter[s]-1); + { U32 const maxBitsOut = tableLog - ZSTD_highbit32 ((U32)normalizedCounter[s]-1); U32 const minStatePlus = (U32)normalizedCounter[s] << maxBitsOut; symbolTT[s].deltaNbBits = (maxBitsOut << 16) - minStatePlus; symbolTT[s].deltaFindState = (int)(total - (unsigned)normalizedCounter[s]); @@ -224,8 +226,8 @@ size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog) size_t const maxHeaderSize = (((maxSymbolValue+1) * tableLog + 4 /* bitCount initialized at 4 */ + 2 /* first two symbols may use one additional bit each */) / 8) - + 1 /* round up to whole nb bytes */ - + 2 /* additional two bytes for bitstream flush */; + + 1 /* round up to whole nb bytes */ + + 2 /* additional two bytes for bitstream flush */; return maxSymbolValue ? maxHeaderSize : FSE_NCOUNTBOUND; /* maxSymbolValue==0 ? use default */ } @@ -254,7 +256,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, /* Init */ remaining = tableSize+1; /* +1 for extra accuracy */ threshold = tableSize; - nbBits = tableLog+1; + nbBits = (int)tableLog+1; while ((symbol < alphabetSize) && (remaining>1)) { /* stops at 1 */ if (previousIs0) { @@ -273,7 +275,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, } while (symbol >= start+3) { start+=3; - bitStream += 3 << bitCount; + bitStream += 3U << bitCount; bitCount += 2; } bitStream += (symbol-start) << bitCount; @@ -293,7 +295,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, count++; /* +1 for extra accuracy */ if (count>=threshold) count += max; /* [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[ */ - bitStream += count << bitCount; + bitStream += (U32)count << bitCount; bitCount += nbBits; bitCount -= (count>8); out+= (bitCount+7) /8; - return (out-ostart); + assert(out >= ostart); + return (size_t)(out-ostart); } @@ -342,21 +345,11 @@ size_t FSE_writeNCount (void* buffer, size_t bufferSize, * FSE Compression Code ****************************************************************/ -FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog) -{ - size_t size; - if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; - size = FSE_CTABLE_SIZE_U32 (tableLog, maxSymbolValue) * sizeof(U32); - return (FSE_CTable*)ZSTD_malloc(size); -} - -void FSE_freeCTable (FSE_CTable* ct) { ZSTD_free(ct); } - /* provides the minimum logSize to safely represent a distribution */ static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) { - U32 minBitsSrc = BIT_highbit32((U32)(srcSize)) + 1; - U32 minBitsSymbols = BIT_highbit32(maxSymbolValue) + 2; + U32 minBitsSrc = ZSTD_highbit32((U32)(srcSize)) + 1; + U32 minBitsSymbols = ZSTD_highbit32(maxSymbolValue) + 2; U32 minBits = minBitsSrc < minBitsSymbols ? minBitsSrc : minBitsSymbols; assert(srcSize > 1); /* Not supported, RLE should be used instead */ return minBits; @@ -364,7 +357,7 @@ static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus) { - U32 maxBitsSrc = BIT_highbit32((U32)(srcSize - 1)) - minus; + U32 maxBitsSrc = ZSTD_highbit32((U32)(srcSize - 1)) - minus; U32 tableLog = maxTableLog; U32 minBits = FSE_minTableLog(srcSize, maxSymbolValue); assert(srcSize > 1); /* Not supported, RLE should be used instead */ @@ -532,40 +525,6 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, return tableLog; } - -/* fake FSE_CTable, for raw (uncompressed) input */ -size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits) -{ - const unsigned tableSize = 1 << nbBits; - const unsigned tableMask = tableSize - 1; - const unsigned maxSymbolValue = tableMask; - void* const ptr = ct; - U16* const tableU16 = ( (U16*) ptr) + 2; - void* const FSCT = ((U32*)ptr) + 1 /* header */ + (tableSize>>1); /* assumption : tableLog >= 1 */ - FSE_symbolCompressionTransform* const symbolTT = (FSE_symbolCompressionTransform*) (FSCT); - unsigned s; - - /* Sanity checks */ - if (nbBits < 1) return ERROR(GENERIC); /* min size */ - - /* header */ - tableU16[-2] = (U16) nbBits; - tableU16[-1] = (U16) maxSymbolValue; - - /* Build table */ - for (s=0; s= 2 + +static size_t showU32(const U32* arr, size_t size) { - return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); + size_t u; + for (u=0; u= sizeof(HUF_WriteCTableWksp)); + + assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue); + assert(HUF_readCTableHeader(CTable).tableLog == huffLog); + /* check conditions */ if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC); if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge); @@ -204,16 +286,6 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, return ((maxSymbolValue+1)/2) + 1; } -/*! HUF_writeCTable() : - `CTable` : Huffman tree to save, using huf representation. - @return : size of saved CTable */ -size_t HUF_writeCTable (void* dst, size_t maxDstSize, - const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog) -{ - HUF_WriteCTableWksp wksp; - return HUF_writeCTable_wksp(dst, maxDstSize, CTable, maxSymbolValue, huffLog, &wksp, sizeof(wksp)); -} - size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* hasZeroWeights) { @@ -231,7 +303,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall); - CTable[0] = tableLog; + *maxSymbolValuePtr = nbSymbols - 1; + + HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr); /* Prepare base value per rank */ { U32 n, nextRankStart = 0; @@ -263,74 +337,71 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void { U32 n; for (n=0; n HUF_readCTableHeader(CTable).maxSymbolValue) + return 0; return (U32)HUF_getNbBits(ct[symbolValue]); } -typedef struct nodeElt_s { - U32 count; - U16 parent; - BYTE byte; - BYTE nbBits; -} nodeElt; - /* * HUF_setMaxHeight(): - * Enforces maxNbBits on the Huffman tree described in huffNode. + * Try to enforce @targetNbBits on the Huffman tree described in @huffNode. * - * It sets all nodes with nbBits > maxNbBits to be maxNbBits. Then it adjusts - * the tree to so that it is a valid canonical Huffman tree. + * It attempts to convert all nodes with nbBits > @targetNbBits + * to employ @targetNbBits instead. Then it adjusts the tree + * so that it remains a valid canonical Huffman tree. * * @pre The sum of the ranks of each symbol == 2^largestBits, * where largestBits == huffNode[lastNonNull].nbBits. * @post The sum of the ranks of each symbol == 2^largestBits, - * where largestBits is the return value <= maxNbBits. + * where largestBits is the return value (expected <= targetNbBits). * - * @param huffNode The Huffman tree modified in place to enforce maxNbBits. + * @param huffNode The Huffman tree modified in place to enforce targetNbBits. + * It's presumed sorted, from most frequent to rarest symbol. * @param lastNonNull The symbol with the lowest count in the Huffman tree. - * @param maxNbBits The maximum allowed number of bits, which the Huffman tree + * @param targetNbBits The allowed number of bits, which the Huffman tree * may not respect. After this function the Huffman tree will - * respect maxNbBits. - * @return The maximum number of bits of the Huffman tree after adjustment, - * necessarily no more than maxNbBits. + * respect targetNbBits. + * @return The maximum number of bits of the Huffman tree after adjustment. */ -static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) +static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 targetNbBits) { const U32 largestBits = huffNode[lastNonNull].nbBits; - /* early exit : no elt > maxNbBits, so the tree is already valid. */ - if (largestBits <= maxNbBits) return largestBits; + /* early exit : no elt > targetNbBits, so the tree is already valid. */ + if (largestBits <= targetNbBits) return largestBits; + + DEBUGLOG(5, "HUF_setMaxHeight (targetNbBits = %u)", targetNbBits); /* there are several too large elements (at least >= 2) */ { int totalCost = 0; - const U32 baseCost = 1 << (largestBits - maxNbBits); + const U32 baseCost = 1 << (largestBits - targetNbBits); int n = (int)lastNonNull; - /* Adjust any ranks > maxNbBits to maxNbBits. + /* Adjust any ranks > targetNbBits to targetNbBits. * Compute totalCost, which is how far the sum of the ranks is * we are over 2^largestBits after adjust the offending ranks. */ - while (huffNode[n].nbBits > maxNbBits) { + while (huffNode[n].nbBits > targetNbBits) { totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)); - huffNode[n].nbBits = (BYTE)maxNbBits; + huffNode[n].nbBits = (BYTE)targetNbBits; n--; } - /* n stops at huffNode[n].nbBits <= maxNbBits */ - assert(huffNode[n].nbBits <= maxNbBits); - /* n end at index of smallest symbol using < maxNbBits */ - while (huffNode[n].nbBits == maxNbBits) --n; + /* n stops at huffNode[n].nbBits <= targetNbBits */ + assert(huffNode[n].nbBits <= targetNbBits); + /* n end at index of smallest symbol using < targetNbBits */ + while (huffNode[n].nbBits == targetNbBits) --n; - /* renorm totalCost from 2^largestBits to 2^maxNbBits + /* renorm totalCost from 2^largestBits to 2^targetNbBits * note : totalCost is necessarily a multiple of baseCost */ - assert((totalCost & (baseCost - 1)) == 0); - totalCost >>= (largestBits - maxNbBits); + assert(((U32)totalCost & (baseCost - 1)) == 0); + totalCost >>= (largestBits - targetNbBits); assert(totalCost > 0); /* repay normalized cost */ @@ -339,19 +410,19 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) /* Get pos of last (smallest = lowest cum. count) symbol per rank */ ZSTD_memset(rankLast, 0xF0, sizeof(rankLast)); - { U32 currentNbBits = maxNbBits; + { U32 currentNbBits = targetNbBits; int pos; for (pos=n ; pos >= 0; pos--) { if (huffNode[pos].nbBits >= currentNbBits) continue; - currentNbBits = huffNode[pos].nbBits; /* < maxNbBits */ - rankLast[maxNbBits-currentNbBits] = (U32)pos; + currentNbBits = huffNode[pos].nbBits; /* < targetNbBits */ + rankLast[targetNbBits-currentNbBits] = (U32)pos; } } while (totalCost > 0) { /* Try to reduce the next power of 2 above totalCost because we * gain back half the rank. */ - U32 nBitsToDecrease = BIT_highbit32((U32)totalCost) + 1; + U32 nBitsToDecrease = ZSTD_highbit32((U32)totalCost) + 1; for ( ; nBitsToDecrease > 1; nBitsToDecrease--) { U32 const highPos = rankLast[nBitsToDecrease]; U32 const lowPos = rankLast[nBitsToDecrease-1]; @@ -391,7 +462,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) rankLast[nBitsToDecrease] = noSymbol; else { rankLast[nBitsToDecrease]--; - if (huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease) + if (huffNode[rankLast[nBitsToDecrease]].nbBits != targetNbBits-nBitsToDecrease) rankLast[nBitsToDecrease] = noSymbol; /* this rank is now empty */ } } /* while (totalCost > 0) */ @@ -403,11 +474,11 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) * TODO. */ while (totalCost < 0) { /* Sometimes, cost correction overshoot */ - /* special case : no rank 1 symbol (using maxNbBits-1); - * let's create one from largest rank 0 (using maxNbBits). + /* special case : no rank 1 symbol (using targetNbBits-1); + * let's create one from largest rank 0 (using targetNbBits). */ if (rankLast[1] == noSymbol) { - while (huffNode[n].nbBits == maxNbBits) n--; + while (huffNode[n].nbBits == targetNbBits) n--; huffNode[n+1].nbBits--; assert(n >= 0); rankLast[1] = (U32)(n+1); @@ -421,7 +492,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) } /* repay normalized cost */ } /* there are several too large elements (at least >= 2) */ - return maxNbBits; + return targetNbBits; } typedef struct { @@ -429,7 +500,7 @@ typedef struct { U16 curr; } rankPos; -typedef nodeElt huffNodeTable[HUF_CTABLE_WORKSPACE_SIZE_U32]; +typedef nodeElt huffNodeTable[2 * (HUF_SYMBOLVALUE_MAX + 1)]; /* Number of buckets available for HUF_sort() */ #define RANK_POSITION_TABLE_SIZE 192 @@ -448,8 +519,8 @@ typedef struct { * Let buckets 166 to 192 represent all remaining counts up to RANK_POSITION_MAX_COUNT_LOG using log2 bucketing. */ #define RANK_POSITION_MAX_COUNT_LOG 32 -#define RANK_POSITION_LOG_BUCKETS_BEGIN (RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */ -#define RANK_POSITION_DISTINCT_COUNT_CUTOFF RANK_POSITION_LOG_BUCKETS_BEGIN + BIT_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */ +#define RANK_POSITION_LOG_BUCKETS_BEGIN ((RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */) +#define RANK_POSITION_DISTINCT_COUNT_CUTOFF (RANK_POSITION_LOG_BUCKETS_BEGIN + ZSTD_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */) /* Return the appropriate bucket index for a given count. See definition of * RANK_POSITION_DISTINCT_COUNT_CUTOFF for explanation of bucketing strategy. @@ -457,7 +528,7 @@ typedef struct { static U32 HUF_getIndex(U32 const count) { return (count < RANK_POSITION_DISTINCT_COUNT_CUTOFF) ? count - : BIT_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; + : ZSTD_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; } /* Helper swap function for HUF_quickSortPartition() */ @@ -580,7 +651,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy /* Sort each bucket. */ for (n = RANK_POSITION_DISTINCT_COUNT_CUTOFF; n < RANK_POSITION_TABLE_SIZE - 1; ++n) { - U32 const bucketSize = rankPosition[n].curr-rankPosition[n].base; + int const bucketSize = rankPosition[n].curr - rankPosition[n].base; U32 const bucketStartIdx = rankPosition[n].base; if (bucketSize > 1) { assert(bucketStartIdx < maxSymbolValue1); @@ -591,6 +662,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy assert(HUF_isSorted(huffNode, maxSymbolValue1)); } + /* HUF_buildCTable_wksp() : * Same as HUF_buildCTable(), but using externally allocated scratch buffer. * `workSpace` must be aligned on 4-bytes boundaries, and be at least as large as sizeof(HUF_buildCTable_wksp_tables). @@ -611,6 +683,7 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) int lowS, lowN; int nodeNb = STARTNODE; int n, nodeRoot; + DEBUGLOG(5, "HUF_buildTree (alphabet size = %u)", maxSymbolValue + 1); /* init for parents */ nonNullRank = (int)maxSymbolValue; while(huffNode[nonNullRank].count == 0) nonNullRank--; @@ -637,6 +710,8 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) for (n=0; n<=nonNullRank; n++) huffNode[n].nbBits = huffNode[ huffNode[n].parent ].nbBits + 1; + DEBUGLOG(6, "Initial distribution of bits completed (%zu sorted symbols)", showHNodeBits(huffNode, maxSymbolValue+1)); + return nonNullRank; } @@ -671,31 +746,40 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits); /* push nbBits per symbol, symbol order */ for (n=0; nhuffNodeTbl; nodeElt* const huffNode = huffNode0+1; int nonNullRank; + HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE == sizeof(HUF_buildCTable_wksp_tables)); + + DEBUGLOG(5, "HUF_buildCTable_wksp (alphabet size = %u)", maxSymbolValue+1); + /* safety checks */ if (wkspSize < sizeof(HUF_buildCTable_wksp_tables)) - return ERROR(workSpace_tooSmall); + return ERROR(workSpace_tooSmall); if (maxNbBits == 0) maxNbBits = HUF_TABLELOG_DEFAULT; if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) - return ERROR(maxSymbolValue_tooLarge); + return ERROR(maxSymbolValue_tooLarge); ZSTD_memset(huffNode0, 0, sizeof(huffNodeTable)); /* sort, decreasing order */ HUF_sort(huffNode, count, maxSymbolValue, wksp_tables->rankPosition); + DEBUGLOG(6, "sorted symbols completed (%zu symbols)", showHNodeSymbols(huffNode, maxSymbolValue+1)); /* build tree */ nonNullRank = HUF_buildTree(huffNode, maxSymbolValue); - /* enforce maxTableLog */ + /* determine and enforce maxTableLog */ maxNbBits = HUF_setMaxHeight(huffNode, (U32)nonNullRank, maxNbBits); if (maxNbBits > HUF_TABLELOG_MAX) return ERROR(GENERIC); /* check fit into table */ @@ -716,13 +800,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, } int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) { - HUF_CElt const* ct = CTable + 1; - int bad = 0; - int s; - for (s = 0; s <= (int)maxSymbolValue; ++s) { - bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); - } - return !bad; + HUF_CTableHeader header = HUF_readCTableHeader(CTable); + HUF_CElt const* ct = CTable + 1; + int bad = 0; + int s; + + assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX); + + if (header.maxSymbolValue < maxSymbolValue) + return 0; + + for (s = 0; s <= (int)maxSymbolValue; ++s) { + bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); + } + return !bad; } size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); } @@ -804,7 +895,7 @@ FORCE_INLINE_TEMPLATE void HUF_addBits(HUF_CStream_t* bitC, HUF_CElt elt, int id #if DEBUGLEVEL >= 1 { size_t const nbBits = HUF_getNbBits(elt); - size_t const dirtyBits = nbBits == 0 ? 0 : BIT_highbit32((U32)nbBits) + 1; + size_t const dirtyBits = nbBits == 0 ? 0 : ZSTD_highbit32((U32)nbBits) + 1; (void)dirtyBits; /* Middle bits are 0. */ assert(((elt >> dirtyBits) << (dirtyBits + nbBits)) == 0); @@ -884,7 +975,7 @@ static size_t HUF_closeCStream(HUF_CStream_t* bitC) { size_t const nbBits = bitC->bitPos[0] & 0xFF; if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */ - return (bitC->ptr - bitC->startPtr) + (nbBits > 0); + return (size_t)(bitC->ptr - bitC->startPtr) + (nbBits > 0); } } @@ -964,17 +1055,17 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) { - U32 const tableLog = (U32)CTable[0]; + U32 const tableLog = HUF_readCTableHeader(CTable).tableLog; HUF_CElt const* ct = CTable + 1; const BYTE* ip = (const BYTE*) src; BYTE* const ostart = (BYTE*)dst; BYTE* const oend = ostart + dstSize; - BYTE* op = ostart; HUF_CStream_t bitC; /* init */ if (dstSize < 8) return 0; /* not enough space to compress */ - { size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); + { BYTE* op = ostart; + size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); if (HUF_isError(initErr)) return 0; } if (dstSize < HUF_tightCompressBound(srcSize, (size_t)tableLog) || tableLog > 11) @@ -1045,9 +1136,9 @@ HUF_compress1X_usingCTable_internal_default(void* dst, size_t dstSize, static size_t HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, const int bmi2) + const HUF_CElt* CTable, const int flags) { - if (bmi2) { + if (flags & HUF_flags_bmi2) { return HUF_compress1X_usingCTable_internal_bmi2(dst, dstSize, src, srcSize, CTable); } return HUF_compress1X_usingCTable_internal_default(dst, dstSize, src, srcSize, CTable); @@ -1058,28 +1149,23 @@ HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, static size_t HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, const int bmi2) + const HUF_CElt* CTable, const int flags) { - (void)bmi2; + (void)flags; return HUF_compress1X_usingCTable_internal_body(dst, dstSize, src, srcSize, CTable); } #endif -size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) +size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) { - return HUF_compress1X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); -} - -size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) -{ - return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); + return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); } static size_t HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, int bmi2) + const HUF_CElt* CTable, int flags) { size_t const segmentSize = (srcSize+3)/4; /* first 3 segments */ const BYTE* ip = (const BYTE*) src; @@ -1093,7 +1179,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, op += 6; /* jumpTable */ assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart, (U16)cSize); op += cSize; @@ -1101,7 +1187,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart+2, (U16)cSize); op += cSize; @@ -1109,7 +1195,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart+4, (U16)cSize); op += cSize; @@ -1118,7 +1204,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); assert(ip <= iend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; op += cSize; } @@ -1126,14 +1212,9 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, return (size_t)(op-ostart); } -size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) -{ - return HUF_compress4X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); -} - -size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) +size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) { - return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); + return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); } typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; @@ -1141,11 +1222,11 @@ typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; static size_t HUF_compressCTable_internal( BYTE* const ostart, BYTE* op, BYTE* const oend, const void* src, size_t srcSize, - HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int bmi2) + HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int flags) { size_t const cSize = (nbStreams==HUF_singleStream) ? - HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2) : - HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2); + HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags) : + HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags); if (HUF_isError(cSize)) { return cSize; } if (cSize==0) { return 0; } /* uncompressible */ op += cSize; @@ -1168,6 +1249,81 @@ typedef struct { #define SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE 4096 #define SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO 10 /* Must be >= 2 */ +unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue) +{ + unsigned cardinality = 0; + unsigned i; + + for (i = 0; i < maxSymbolValue + 1; i++) { + if (count[i] != 0) cardinality += 1; + } + + return cardinality; +} + +unsigned HUF_minTableLog(unsigned symbolCardinality) +{ + U32 minBitsSymbols = ZSTD_highbit32(symbolCardinality) + 1; + return minBitsSymbols; +} + +unsigned HUF_optimalTableLog( + unsigned maxTableLog, + size_t srcSize, + unsigned maxSymbolValue, + void* workSpace, size_t wkspSize, + HUF_CElt* table, + const unsigned* count, + int flags) +{ + assert(srcSize > 1); /* Not supported, RLE should be used instead */ + assert(wkspSize >= sizeof(HUF_buildCTable_wksp_tables)); + + if (!(flags & HUF_flags_optimalDepth)) { + /* cheap evaluation, based on FSE */ + return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); + } + + { BYTE* dst = (BYTE*)workSpace + sizeof(HUF_WriteCTableWksp); + size_t dstSize = wkspSize - sizeof(HUF_WriteCTableWksp); + size_t hSize, newSize; + const unsigned symbolCardinality = HUF_cardinality(count, maxSymbolValue); + const unsigned minTableLog = HUF_minTableLog(symbolCardinality); + size_t optSize = ((size_t) ~0) - 1; + unsigned optLog = maxTableLog, optLogGuess; + + DEBUGLOG(6, "HUF_optimalTableLog: probing huf depth (srcSize=%zu)", srcSize); + + /* Search until size increases */ + for (optLogGuess = minTableLog; optLogGuess <= maxTableLog; optLogGuess++) { + DEBUGLOG(7, "checking for huffLog=%u", optLogGuess); + + { size_t maxBits = HUF_buildCTable_wksp(table, count, maxSymbolValue, optLogGuess, workSpace, wkspSize); + if (ERR_isError(maxBits)) continue; + + if (maxBits < optLogGuess && optLogGuess > minTableLog) break; + + hSize = HUF_writeCTable_wksp(dst, dstSize, table, maxSymbolValue, (U32)maxBits, workSpace, wkspSize); + } + + if (ERR_isError(hSize)) continue; + + newSize = HUF_estimateCompressedSize(table, count, maxSymbolValue) + hSize; + + if (newSize > optSize + 1) { + break; + } + + if (newSize < optSize) { + optSize = newSize; + optLog = optLogGuess; + } + } + assert(optLog <= HUF_TABLELOG_MAX); + return optLog; + } +} + /* HUF_compress_internal() : * `workSpace_align4` must be aligned on 4-bytes boundaries, * and occupies the same space as a table of HUF_WORKSPACE_SIZE_U64 unsigned */ @@ -1177,14 +1333,14 @@ HUF_compress_internal (void* dst, size_t dstSize, unsigned maxSymbolValue, unsigned huffLog, HUF_nbStreams_e nbStreams, void* workSpace, size_t wkspSize, - HUF_CElt* oldHufTable, HUF_repeat* repeat, int preferRepeat, - const int bmi2, unsigned suspectUncompressible) + HUF_CElt* oldHufTable, HUF_repeat* repeat, int flags) { HUF_compress_tables_t* const table = (HUF_compress_tables_t*)HUF_alignUpWorkspace(workSpace, &wkspSize, ZSTD_ALIGNOF(size_t)); BYTE* const ostart = (BYTE*)dst; BYTE* const oend = ostart + dstSize; BYTE* op = ostart; + DEBUGLOG(5, "HUF_compress_internal (srcSize=%zu)", srcSize); HUF_STATIC_ASSERT(sizeof(*table) + HUF_WORKSPACE_MAX_ALIGNMENT <= HUF_WORKSPACE_SIZE); /* checks & inits */ @@ -1198,16 +1354,17 @@ HUF_compress_internal (void* dst, size_t dstSize, if (!huffLog) huffLog = HUF_TABLELOG_DEFAULT; /* Heuristic : If old table is valid, use it for small inputs */ - if (preferRepeat && repeat && *repeat == HUF_repeat_valid) { + if ((flags & HUF_flags_preferRepeat) && repeat && *repeat == HUF_repeat_valid) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } /* If uncompressible data is suspected, do a smaller sampling first */ DEBUG_STATIC_ASSERT(SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO >= 2); - if (suspectUncompressible && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { + if ((flags & HUF_flags_suspectUncompressible) && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { size_t largestTotal = 0; + DEBUGLOG(5, "input suspected incompressible : sampling to check"); { unsigned maxSymbolValueBegin = maxSymbolValue; CHECK_V_F(largestBegin, HIST_count_simple (table->count, &maxSymbolValueBegin, (const BYTE*)src, SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE) ); largestTotal += largestBegin; @@ -1224,6 +1381,7 @@ HUF_compress_internal (void* dst, size_t dstSize, if (largest == srcSize) { *ostart = ((const BYTE*)src)[0]; return 1; } /* single symbol, rle */ if (largest <= (srcSize >> 7)+4) return 0; /* heuristic : probably not compressible enough */ } + DEBUGLOG(6, "histogram detail completed (%zu symbols)", showU32(table->count, maxSymbolValue+1)); /* Check validity of previous table */ if ( repeat @@ -1232,25 +1390,20 @@ HUF_compress_internal (void* dst, size_t dstSize, *repeat = HUF_repeat_none; } /* Heuristic : use existing table for small inputs */ - if (preferRepeat && repeat && *repeat != HUF_repeat_none) { + if ((flags & HUF_flags_preferRepeat) && repeat && *repeat != HUF_repeat_none) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } /* Build Huffman Tree */ - huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); + huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, &table->wksps, sizeof(table->wksps), table->CTable, table->count, flags); { size_t const maxBits = HUF_buildCTable_wksp(table->CTable, table->count, maxSymbolValue, huffLog, &table->wksps.buildCTable_wksp, sizeof(table->wksps.buildCTable_wksp)); CHECK_F(maxBits); huffLog = (U32)maxBits; - } - /* Zero unused symbols in CTable, so we can check it for validity */ - { - size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue); - size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt); - ZSTD_memset(table->CTable + ctableSize, 0, unusedSize); + DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1)); } /* Write table description header */ @@ -1263,7 +1416,7 @@ HUF_compress_internal (void* dst, size_t dstSize, if (oldSize <= hSize + newSize || hSize + 12 >= srcSize) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } } /* Use the new huffman table */ @@ -1275,61 +1428,35 @@ HUF_compress_internal (void* dst, size_t dstSize, } return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, table->CTable, bmi2); -} - - -size_t HUF_compress1X_wksp (void* dst, size_t dstSize, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned huffLog, - void* workSpace, size_t wkspSize) -{ - return HUF_compress_internal(dst, dstSize, src, srcSize, - maxSymbolValue, huffLog, HUF_singleStream, - workSpace, wkspSize, - NULL, NULL, 0, 0 /*bmi2*/, 0); + nbStreams, table->CTable, flags); } size_t HUF_compress1X_repeat (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void* workSpace, size_t wkspSize, - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, - int bmi2, unsigned suspectUncompressible) + HUF_CElt* hufTable, HUF_repeat* repeat, int flags) { + DEBUGLOG(5, "HUF_compress1X_repeat (srcSize = %zu)", srcSize); return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, HUF_singleStream, workSpace, wkspSize, hufTable, - repeat, preferRepeat, bmi2, suspectUncompressible); -} - -/* HUF_compress4X_repeat(): - * compress input using 4 streams. - * provide workspace to generate compression tables */ -size_t HUF_compress4X_wksp (void* dst, size_t dstSize, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned huffLog, - void* workSpace, size_t wkspSize) -{ - return HUF_compress_internal(dst, dstSize, src, srcSize, - maxSymbolValue, huffLog, HUF_fourStreams, - workSpace, wkspSize, - NULL, NULL, 0, 0 /*bmi2*/, 0); + repeat, flags); } /* HUF_compress4X_repeat(): * compress input using 4 streams. * consider skipping quickly - * re-use an existing huffman compression table */ + * reuse an existing huffman compression table */ size_t HUF_compress4X_repeat (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void* workSpace, size_t wkspSize, - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible) + HUF_CElt* hufTable, HUF_repeat* repeat, int flags) { + DEBUGLOG(5, "HUF_compress4X_repeat (srcSize = %zu)", srcSize); return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, HUF_fourStreams, workSpace, wkspSize, - hufTable, repeat, preferRepeat, bmi2, suspectUncompressible); + hufTable, repeat, flags); } - diff --git a/lib/zstd/compress/zstd_compress.c b/lib/zstd/compress/zstd_compress.c index 16bb995bc6c4..c41a747413e0 100644 --- a/lib/zstd/compress/zstd_compress.c +++ b/lib/zstd/compress/zstd_compress.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,12 +12,13 @@ /*-************************************* * Dependencies ***************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ #include "../common/zstd_deps.h" /* INT_MAX, ZSTD_memset, ZSTD_memcpy */ #include "../common/mem.h" +#include "../common/error_private.h" #include "hist.h" /* HIST_countFast_wksp */ #define FSE_STATIC_LINKING_ONLY /* FSE_encodeSymbol */ #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "zstd_compress_internal.h" #include "zstd_compress_sequences.h" @@ -27,6 +29,7 @@ #include "zstd_opt.h" #include "zstd_ldm.h" #include "zstd_compress_superblock.h" +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_rotateRight_U64 */ /* *************************************************************** * Tuning parameters @@ -44,7 +47,7 @@ * in log format, aka 17 => 1 << 17 == 128Ki positions. * This structure is only used in zstd_opt. * Since allocation is centralized for all strategies, it has to be known here. - * The actual (selected) size of the hash table is then stored in ZSTD_matchState_t.hashLog3, + * The actual (selected) size of the hash table is then stored in ZSTD_MatchState_t.hashLog3, * so that zstd_opt.c doesn't need to know about this constant. */ #ifndef ZSTD_HASHLOG3_MAX @@ -55,14 +58,17 @@ * Helper functions ***************************************/ /* ZSTD_compressBound() - * Note that the result from this function is only compatible with the "normal" - * full-block strategy. - * When there are a lot of small blocks due to frequent flush in streaming mode - * the overhead of headers can make the compressed data to be larger than the - * return value of ZSTD_compressBound(). + * Note that the result from this function is only valid for + * the one-pass compression functions. + * When employing the streaming mode, + * if flushes are frequently altering the size of blocks, + * the overhead from block headers can make the compressed data larger + * than the return value of ZSTD_compressBound(). */ size_t ZSTD_compressBound(size_t srcSize) { - return ZSTD_COMPRESSBOUND(srcSize); + size_t const r = ZSTD_COMPRESSBOUND(srcSize); + if (r==0) return ERROR(srcSize_wrong); + return r; } @@ -75,12 +81,12 @@ struct ZSTD_CDict_s { ZSTD_dictContentType_e dictContentType; /* The dictContentType the CDict was created with */ U32* entropyWorkspace; /* entropy workspace of HUF_WORKSPACE_SIZE bytes */ ZSTD_cwksp workspace; - ZSTD_matchState_t matchState; + ZSTD_MatchState_t matchState; ZSTD_compressedBlockState_t cBlockState; ZSTD_customMem customMem; U32 dictID; int compressionLevel; /* 0 indicates that advanced API was used to select CDict params */ - ZSTD_paramSwitch_e useRowMatchFinder; /* Indicates whether the CDict was created with params that would use + ZSTD_ParamSwitch_e useRowMatchFinder; /* Indicates whether the CDict was created with params that would use * row-based matchfinder. Unless the cdict is reloaded, we will use * the same greedy/lazy matchfinder at compression time. */ @@ -130,11 +136,12 @@ ZSTD_CCtx* ZSTD_initStaticCCtx(void* workspace, size_t workspaceSize) ZSTD_cwksp_move(&cctx->workspace, &ws); cctx->staticSize = workspaceSize; - /* statically sized space. entropyWorkspace never moves (but prev/next block swap places) */ - if (!ZSTD_cwksp_check_available(&cctx->workspace, ENTROPY_WORKSPACE_SIZE + 2 * sizeof(ZSTD_compressedBlockState_t))) return NULL; + /* statically sized space. tmpWorkspace never moves (but prev/next block swap places) */ + if (!ZSTD_cwksp_check_available(&cctx->workspace, TMP_WORKSPACE_SIZE + 2 * sizeof(ZSTD_compressedBlockState_t))) return NULL; cctx->blockState.prevCBlock = (ZSTD_compressedBlockState_t*)ZSTD_cwksp_reserve_object(&cctx->workspace, sizeof(ZSTD_compressedBlockState_t)); cctx->blockState.nextCBlock = (ZSTD_compressedBlockState_t*)ZSTD_cwksp_reserve_object(&cctx->workspace, sizeof(ZSTD_compressedBlockState_t)); - cctx->entropyWorkspace = (U32*)ZSTD_cwksp_reserve_object(&cctx->workspace, ENTROPY_WORKSPACE_SIZE); + cctx->tmpWorkspace = ZSTD_cwksp_reserve_object(&cctx->workspace, TMP_WORKSPACE_SIZE); + cctx->tmpWkspSize = TMP_WORKSPACE_SIZE; cctx->bmi2 = ZSTD_cpuid_bmi2(ZSTD_cpuid()); return cctx; } @@ -168,15 +175,13 @@ static void ZSTD_freeCCtxContent(ZSTD_CCtx* cctx) size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx) { + DEBUGLOG(3, "ZSTD_freeCCtx (address: %p)", (void*)cctx); if (cctx==NULL) return 0; /* support free on NULL */ RETURN_ERROR_IF(cctx->staticSize, memory_allocation, "not compatible with static CCtx"); - { - int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); + { int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); ZSTD_freeCCtxContent(cctx); - if (!cctxInWorkspace) { - ZSTD_customFree(cctx, cctx->customMem); - } + if (!cctxInWorkspace) ZSTD_customFree(cctx, cctx->customMem); } return 0; } @@ -205,7 +210,7 @@ size_t ZSTD_sizeof_CStream(const ZSTD_CStream* zcs) } /* private API call, for dictBuilder only */ -const seqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx) { return &(ctx->seqStore); } +const SeqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx) { return &(ctx->seqStore); } /* Returns true if the strategy supports using a row based matchfinder */ static int ZSTD_rowMatchFinderSupported(const ZSTD_strategy strategy) { @@ -215,32 +220,27 @@ static int ZSTD_rowMatchFinderSupported(const ZSTD_strategy strategy) { /* Returns true if the strategy and useRowMatchFinder mode indicate that we will use the row based matchfinder * for this compression. */ -static int ZSTD_rowMatchFinderUsed(const ZSTD_strategy strategy, const ZSTD_paramSwitch_e mode) { +static int ZSTD_rowMatchFinderUsed(const ZSTD_strategy strategy, const ZSTD_ParamSwitch_e mode) { assert(mode != ZSTD_ps_auto); return ZSTD_rowMatchFinderSupported(strategy) && (mode == ZSTD_ps_enable); } /* Returns row matchfinder usage given an initial mode and cParams */ -static ZSTD_paramSwitch_e ZSTD_resolveRowMatchFinderMode(ZSTD_paramSwitch_e mode, +static ZSTD_ParamSwitch_e ZSTD_resolveRowMatchFinderMode(ZSTD_ParamSwitch_e mode, const ZSTD_compressionParameters* const cParams) { -#if defined(ZSTD_ARCH_X86_SSE2) || defined(ZSTD_ARCH_ARM_NEON) - int const kHasSIMD128 = 1; -#else - int const kHasSIMD128 = 0; -#endif + /* The Linux Kernel does not use SIMD, and 128KB is a very common size, e.g. in BtrFS. + * The row match finder is slower for this size without SIMD, so disable it. + */ + const unsigned kWindowLogLowerBound = 17; if (mode != ZSTD_ps_auto) return mode; /* if requested enabled, but no SIMD, we still will use row matchfinder */ mode = ZSTD_ps_disable; if (!ZSTD_rowMatchFinderSupported(cParams->strategy)) return mode; - if (kHasSIMD128) { - if (cParams->windowLog > 14) mode = ZSTD_ps_enable; - } else { - if (cParams->windowLog > 17) mode = ZSTD_ps_enable; - } + if (cParams->windowLog > kWindowLogLowerBound) mode = ZSTD_ps_enable; return mode; } /* Returns block splitter usage (generally speaking, when using slower/stronger compression modes) */ -static ZSTD_paramSwitch_e ZSTD_resolveBlockSplitterMode(ZSTD_paramSwitch_e mode, +static ZSTD_ParamSwitch_e ZSTD_resolveBlockSplitterMode(ZSTD_ParamSwitch_e mode, const ZSTD_compressionParameters* const cParams) { if (mode != ZSTD_ps_auto) return mode; return (cParams->strategy >= ZSTD_btopt && cParams->windowLog >= 17) ? ZSTD_ps_enable : ZSTD_ps_disable; @@ -248,7 +248,7 @@ static ZSTD_paramSwitch_e ZSTD_resolveBlockSplitterMode(ZSTD_paramSwitch_e mode, /* Returns 1 if the arguments indicate that we should allocate a chainTable, 0 otherwise */ static int ZSTD_allocateChainTable(const ZSTD_strategy strategy, - const ZSTD_paramSwitch_e useRowMatchFinder, + const ZSTD_ParamSwitch_e useRowMatchFinder, const U32 forDDSDict) { assert(useRowMatchFinder != ZSTD_ps_auto); /* We always should allocate a chaintable if we are allocating a matchstate for a DDS dictionary matchstate. @@ -257,16 +257,44 @@ static int ZSTD_allocateChainTable(const ZSTD_strategy strategy, return forDDSDict || ((strategy != ZSTD_fast) && !ZSTD_rowMatchFinderUsed(strategy, useRowMatchFinder)); } -/* Returns 1 if compression parameters are such that we should +/* Returns ZSTD_ps_enable if compression parameters are such that we should * enable long distance matching (wlog >= 27, strategy >= btopt). - * Returns 0 otherwise. + * Returns ZSTD_ps_disable otherwise. */ -static ZSTD_paramSwitch_e ZSTD_resolveEnableLdm(ZSTD_paramSwitch_e mode, +static ZSTD_ParamSwitch_e ZSTD_resolveEnableLdm(ZSTD_ParamSwitch_e mode, const ZSTD_compressionParameters* const cParams) { if (mode != ZSTD_ps_auto) return mode; return (cParams->strategy >= ZSTD_btopt && cParams->windowLog >= 27) ? ZSTD_ps_enable : ZSTD_ps_disable; } +static int ZSTD_resolveExternalSequenceValidation(int mode) { + return mode; +} + +/* Resolves maxBlockSize to the default if no value is present. */ +static size_t ZSTD_resolveMaxBlockSize(size_t maxBlockSize) { + if (maxBlockSize == 0) { + return ZSTD_BLOCKSIZE_MAX; + } else { + return maxBlockSize; + } +} + +static ZSTD_ParamSwitch_e ZSTD_resolveExternalRepcodeSearch(ZSTD_ParamSwitch_e value, int cLevel) { + if (value != ZSTD_ps_auto) return value; + if (cLevel < 10) { + return ZSTD_ps_disable; + } else { + return ZSTD_ps_enable; + } +} + +/* Returns 1 if compression parameters are such that CDict hashtable and chaintable indices are tagged. + * If so, the tags need to be removed in ZSTD_resetCCtx_byCopyingCDict. */ +static int ZSTD_CDictIndicesAreTagged(const ZSTD_compressionParameters* const cParams) { + return cParams->strategy == ZSTD_fast || cParams->strategy == ZSTD_dfast; +} + static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( ZSTD_compressionParameters cParams) { @@ -282,8 +310,12 @@ static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( assert(cctxParams.ldmParams.hashLog >= cctxParams.ldmParams.bucketSizeLog); assert(cctxParams.ldmParams.hashRateLog < 32); } - cctxParams.useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams.useBlockSplitter, &cParams); + cctxParams.postBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams.postBlockSplitter, &cParams); cctxParams.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams.useRowMatchFinder, &cParams); + cctxParams.validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams.validateSequences); + cctxParams.maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams.maxBlockSize); + cctxParams.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams.searchForExternalRepcodes, + cctxParams.compressionLevel); assert(!ZSTD_checkCParams(cParams)); return cctxParams; } @@ -329,10 +361,13 @@ size_t ZSTD_CCtxParams_init(ZSTD_CCtx_params* cctxParams, int compressionLevel) #define ZSTD_NO_CLEVEL 0 /* - * Initializes the cctxParams from params and compressionLevel. + * Initializes `cctxParams` from `params` and `compressionLevel`. * @param compressionLevel If params are derived from a compression level then that compression level, otherwise ZSTD_NO_CLEVEL. */ -static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_parameters const* params, int compressionLevel) +static void +ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, + const ZSTD_parameters* params, + int compressionLevel) { assert(!ZSTD_checkCParams(params->cParams)); ZSTD_memset(cctxParams, 0, sizeof(*cctxParams)); @@ -343,10 +378,13 @@ static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_par */ cctxParams->compressionLevel = compressionLevel; cctxParams->useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams->useRowMatchFinder, ¶ms->cParams); - cctxParams->useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams->useBlockSplitter, ¶ms->cParams); + cctxParams->postBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams->postBlockSplitter, ¶ms->cParams); cctxParams->ldmParams.enableLdm = ZSTD_resolveEnableLdm(cctxParams->ldmParams.enableLdm, ¶ms->cParams); + cctxParams->validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams->validateSequences); + cctxParams->maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams->maxBlockSize); + cctxParams->searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams->searchForExternalRepcodes, compressionLevel); DEBUGLOG(4, "ZSTD_CCtxParams_init_internal: useRowMatchFinder=%d, useBlockSplitter=%d ldm=%d", - cctxParams->useRowMatchFinder, cctxParams->useBlockSplitter, cctxParams->ldmParams.enableLdm); + cctxParams->useRowMatchFinder, cctxParams->postBlockSplitter, cctxParams->ldmParams.enableLdm); } size_t ZSTD_CCtxParams_init_advanced(ZSTD_CCtx_params* cctxParams, ZSTD_parameters params) @@ -359,7 +397,7 @@ size_t ZSTD_CCtxParams_init_advanced(ZSTD_CCtx_params* cctxParams, ZSTD_paramete /* * Sets cctxParams' cParams and fParams from params, but otherwise leaves them alone. - * @param param Validated zstd parameters. + * @param params Validated zstd parameters. */ static void ZSTD_CCtxParams_setZstdParams( ZSTD_CCtx_params* cctxParams, const ZSTD_parameters* params) @@ -455,8 +493,8 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) return bounds; case ZSTD_c_enableLongDistanceMatching: - bounds.lowerBound = 0; - bounds.upperBound = 1; + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; return bounds; case ZSTD_c_ldmHashLog: @@ -534,11 +572,16 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) bounds.upperBound = 1; return bounds; - case ZSTD_c_useBlockSplitter: + case ZSTD_c_splitAfterSequences: bounds.lowerBound = (int)ZSTD_ps_auto; bounds.upperBound = (int)ZSTD_ps_disable; return bounds; + case ZSTD_c_blockSplitterLevel: + bounds.lowerBound = 0; + bounds.upperBound = ZSTD_BLOCKSPLITTER_LEVEL_MAX; + return bounds; + case ZSTD_c_useRowMatchFinder: bounds.lowerBound = (int)ZSTD_ps_auto; bounds.upperBound = (int)ZSTD_ps_disable; @@ -549,6 +592,26 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) bounds.upperBound = 1; return bounds; + case ZSTD_c_prefetchCDictTables: + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; + return bounds; + + case ZSTD_c_enableSeqProducerFallback: + bounds.lowerBound = 0; + bounds.upperBound = 1; + return bounds; + + case ZSTD_c_maxBlockSize: + bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; + bounds.upperBound = ZSTD_BLOCKSIZE_MAX; + return bounds; + + case ZSTD_c_repcodeResolution: + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; + return bounds; + default: bounds.error = ERROR(parameter_unsupported); return bounds; @@ -567,10 +630,11 @@ static size_t ZSTD_cParam_clampBounds(ZSTD_cParameter cParam, int* value) return 0; } -#define BOUNDCHECK(cParam, val) { \ - RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ - parameter_outOfBound, "Param out of bounds"); \ -} +#define BOUNDCHECK(cParam, val) \ + do { \ + RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ + parameter_outOfBound, "Param out of bounds"); \ + } while (0) static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) @@ -584,6 +648,7 @@ static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) case ZSTD_c_minMatch: case ZSTD_c_targetLength: case ZSTD_c_strategy: + case ZSTD_c_blockSplitterLevel: return 1; case ZSTD_c_format: @@ -610,9 +675,13 @@ static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) case ZSTD_c_stableOutBuffer: case ZSTD_c_blockDelimiters: case ZSTD_c_validateSequences: - case ZSTD_c_useBlockSplitter: + case ZSTD_c_splitAfterSequences: case ZSTD_c_useRowMatchFinder: case ZSTD_c_deterministicRefPrefix: + case ZSTD_c_prefetchCDictTables: + case ZSTD_c_enableSeqProducerFallback: + case ZSTD_c_maxBlockSize: + case ZSTD_c_repcodeResolution: default: return 0; } @@ -625,7 +694,7 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) if (ZSTD_isUpdateAuthorized(param)) { cctx->cParamsChanged = 1; } else { - RETURN_ERROR(stage_wrong, "can only set params in ctx init stage"); + RETURN_ERROR(stage_wrong, "can only set params in cctx init stage"); } } switch(param) @@ -665,9 +734,14 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) case ZSTD_c_stableOutBuffer: case ZSTD_c_blockDelimiters: case ZSTD_c_validateSequences: - case ZSTD_c_useBlockSplitter: + case ZSTD_c_splitAfterSequences: + case ZSTD_c_blockSplitterLevel: case ZSTD_c_useRowMatchFinder: case ZSTD_c_deterministicRefPrefix: + case ZSTD_c_prefetchCDictTables: + case ZSTD_c_enableSeqProducerFallback: + case ZSTD_c_maxBlockSize: + case ZSTD_c_repcodeResolution: break; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); @@ -723,12 +797,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_minMatch : if (value!=0) /* 0 => use default */ BOUNDCHECK(ZSTD_c_minMatch, value); - CCtxParams->cParams.minMatch = value; + CCtxParams->cParams.minMatch = (U32)value; return CCtxParams->cParams.minMatch; case ZSTD_c_targetLength : BOUNDCHECK(ZSTD_c_targetLength, value); - CCtxParams->cParams.targetLength = value; + CCtxParams->cParams.targetLength = (U32)value; return CCtxParams->cParams.targetLength; case ZSTD_c_strategy : @@ -741,12 +815,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, /* Content size written in frame header _when known_ (default:1) */ DEBUGLOG(4, "set content size flag = %u", (value!=0)); CCtxParams->fParams.contentSizeFlag = value != 0; - return CCtxParams->fParams.contentSizeFlag; + return (size_t)CCtxParams->fParams.contentSizeFlag; case ZSTD_c_checksumFlag : /* A 32-bits content checksum will be calculated and written at end of frame (default:0) */ CCtxParams->fParams.checksumFlag = value != 0; - return CCtxParams->fParams.checksumFlag; + return (size_t)CCtxParams->fParams.checksumFlag; case ZSTD_c_dictIDFlag : /* When applicable, dictionary's dictID is provided in frame header (default:1) */ DEBUGLOG(4, "set dictIDFlag = %u", (value!=0)); @@ -755,18 +829,18 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_forceMaxWindow : CCtxParams->forceWindow = (value != 0); - return CCtxParams->forceWindow; + return (size_t)CCtxParams->forceWindow; case ZSTD_c_forceAttachDict : { const ZSTD_dictAttachPref_e pref = (ZSTD_dictAttachPref_e)value; - BOUNDCHECK(ZSTD_c_forceAttachDict, pref); + BOUNDCHECK(ZSTD_c_forceAttachDict, (int)pref); CCtxParams->attachDictPref = pref; return CCtxParams->attachDictPref; } case ZSTD_c_literalCompressionMode : { - const ZSTD_paramSwitch_e lcm = (ZSTD_paramSwitch_e)value; - BOUNDCHECK(ZSTD_c_literalCompressionMode, lcm); + const ZSTD_ParamSwitch_e lcm = (ZSTD_ParamSwitch_e)value; + BOUNDCHECK(ZSTD_c_literalCompressionMode, (int)lcm); CCtxParams->literalCompressionMode = lcm; return CCtxParams->literalCompressionMode; } @@ -789,47 +863,50 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_enableDedicatedDictSearch : CCtxParams->enableDedicatedDictSearch = (value!=0); - return CCtxParams->enableDedicatedDictSearch; + return (size_t)CCtxParams->enableDedicatedDictSearch; case ZSTD_c_enableLongDistanceMatching : - CCtxParams->ldmParams.enableLdm = (ZSTD_paramSwitch_e)value; + BOUNDCHECK(ZSTD_c_enableLongDistanceMatching, value); + CCtxParams->ldmParams.enableLdm = (ZSTD_ParamSwitch_e)value; return CCtxParams->ldmParams.enableLdm; case ZSTD_c_ldmHashLog : if (value!=0) /* 0 ==> auto */ BOUNDCHECK(ZSTD_c_ldmHashLog, value); - CCtxParams->ldmParams.hashLog = value; + CCtxParams->ldmParams.hashLog = (U32)value; return CCtxParams->ldmParams.hashLog; case ZSTD_c_ldmMinMatch : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmMinMatch, value); - CCtxParams->ldmParams.minMatchLength = value; + CCtxParams->ldmParams.minMatchLength = (U32)value; return CCtxParams->ldmParams.minMatchLength; case ZSTD_c_ldmBucketSizeLog : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmBucketSizeLog, value); - CCtxParams->ldmParams.bucketSizeLog = value; + CCtxParams->ldmParams.bucketSizeLog = (U32)value; return CCtxParams->ldmParams.bucketSizeLog; case ZSTD_c_ldmHashRateLog : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmHashRateLog, value); - CCtxParams->ldmParams.hashRateLog = value; + CCtxParams->ldmParams.hashRateLog = (U32)value; return CCtxParams->ldmParams.hashRateLog; case ZSTD_c_targetCBlockSize : - if (value!=0) /* 0 ==> default */ + if (value!=0) { /* 0 ==> default */ + value = MAX(value, ZSTD_TARGETCBLOCKSIZE_MIN); BOUNDCHECK(ZSTD_c_targetCBlockSize, value); - CCtxParams->targetCBlockSize = value; + } + CCtxParams->targetCBlockSize = (U32)value; return CCtxParams->targetCBlockSize; case ZSTD_c_srcSizeHint : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_srcSizeHint, value); CCtxParams->srcSizeHint = value; - return CCtxParams->srcSizeHint; + return (size_t)CCtxParams->srcSizeHint; case ZSTD_c_stableInBuffer: BOUNDCHECK(ZSTD_c_stableInBuffer, value); @@ -843,28 +920,55 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_blockDelimiters: BOUNDCHECK(ZSTD_c_blockDelimiters, value); - CCtxParams->blockDelimiters = (ZSTD_sequenceFormat_e)value; + CCtxParams->blockDelimiters = (ZSTD_SequenceFormat_e)value; return CCtxParams->blockDelimiters; case ZSTD_c_validateSequences: BOUNDCHECK(ZSTD_c_validateSequences, value); CCtxParams->validateSequences = value; - return CCtxParams->validateSequences; + return (size_t)CCtxParams->validateSequences; + + case ZSTD_c_splitAfterSequences: + BOUNDCHECK(ZSTD_c_splitAfterSequences, value); + CCtxParams->postBlockSplitter = (ZSTD_ParamSwitch_e)value; + return CCtxParams->postBlockSplitter; - case ZSTD_c_useBlockSplitter: - BOUNDCHECK(ZSTD_c_useBlockSplitter, value); - CCtxParams->useBlockSplitter = (ZSTD_paramSwitch_e)value; - return CCtxParams->useBlockSplitter; + case ZSTD_c_blockSplitterLevel: + BOUNDCHECK(ZSTD_c_blockSplitterLevel, value); + CCtxParams->preBlockSplitter_level = value; + return (size_t)CCtxParams->preBlockSplitter_level; case ZSTD_c_useRowMatchFinder: BOUNDCHECK(ZSTD_c_useRowMatchFinder, value); - CCtxParams->useRowMatchFinder = (ZSTD_paramSwitch_e)value; + CCtxParams->useRowMatchFinder = (ZSTD_ParamSwitch_e)value; return CCtxParams->useRowMatchFinder; case ZSTD_c_deterministicRefPrefix: BOUNDCHECK(ZSTD_c_deterministicRefPrefix, value); CCtxParams->deterministicRefPrefix = !!value; - return CCtxParams->deterministicRefPrefix; + return (size_t)CCtxParams->deterministicRefPrefix; + + case ZSTD_c_prefetchCDictTables: + BOUNDCHECK(ZSTD_c_prefetchCDictTables, value); + CCtxParams->prefetchCDictTables = (ZSTD_ParamSwitch_e)value; + return CCtxParams->prefetchCDictTables; + + case ZSTD_c_enableSeqProducerFallback: + BOUNDCHECK(ZSTD_c_enableSeqProducerFallback, value); + CCtxParams->enableMatchFinderFallback = value; + return (size_t)CCtxParams->enableMatchFinderFallback; + + case ZSTD_c_maxBlockSize: + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_maxBlockSize, value); + assert(value>=0); + CCtxParams->maxBlockSize = (size_t)value; + return CCtxParams->maxBlockSize; + + case ZSTD_c_repcodeResolution: + BOUNDCHECK(ZSTD_c_repcodeResolution, value); + CCtxParams->searchForExternalRepcodes = (ZSTD_ParamSwitch_e)value; + return CCtxParams->searchForExternalRepcodes; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); } @@ -881,7 +985,7 @@ size_t ZSTD_CCtxParams_getParameter( switch(param) { case ZSTD_c_format : - *value = CCtxParams->format; + *value = (int)CCtxParams->format; break; case ZSTD_c_compressionLevel : *value = CCtxParams->compressionLevel; @@ -896,16 +1000,16 @@ size_t ZSTD_CCtxParams_getParameter( *value = (int)CCtxParams->cParams.chainLog; break; case ZSTD_c_searchLog : - *value = CCtxParams->cParams.searchLog; + *value = (int)CCtxParams->cParams.searchLog; break; case ZSTD_c_minMatch : - *value = CCtxParams->cParams.minMatch; + *value = (int)CCtxParams->cParams.minMatch; break; case ZSTD_c_targetLength : - *value = CCtxParams->cParams.targetLength; + *value = (int)CCtxParams->cParams.targetLength; break; case ZSTD_c_strategy : - *value = (unsigned)CCtxParams->cParams.strategy; + *value = (int)CCtxParams->cParams.strategy; break; case ZSTD_c_contentSizeFlag : *value = CCtxParams->fParams.contentSizeFlag; @@ -920,10 +1024,10 @@ size_t ZSTD_CCtxParams_getParameter( *value = CCtxParams->forceWindow; break; case ZSTD_c_forceAttachDict : - *value = CCtxParams->attachDictPref; + *value = (int)CCtxParams->attachDictPref; break; case ZSTD_c_literalCompressionMode : - *value = CCtxParams->literalCompressionMode; + *value = (int)CCtxParams->literalCompressionMode; break; case ZSTD_c_nbWorkers : assert(CCtxParams->nbWorkers == 0); @@ -939,19 +1043,19 @@ size_t ZSTD_CCtxParams_getParameter( *value = CCtxParams->enableDedicatedDictSearch; break; case ZSTD_c_enableLongDistanceMatching : - *value = CCtxParams->ldmParams.enableLdm; + *value = (int)CCtxParams->ldmParams.enableLdm; break; case ZSTD_c_ldmHashLog : - *value = CCtxParams->ldmParams.hashLog; + *value = (int)CCtxParams->ldmParams.hashLog; break; case ZSTD_c_ldmMinMatch : - *value = CCtxParams->ldmParams.minMatchLength; + *value = (int)CCtxParams->ldmParams.minMatchLength; break; case ZSTD_c_ldmBucketSizeLog : - *value = CCtxParams->ldmParams.bucketSizeLog; + *value = (int)CCtxParams->ldmParams.bucketSizeLog; break; case ZSTD_c_ldmHashRateLog : - *value = CCtxParams->ldmParams.hashRateLog; + *value = (int)CCtxParams->ldmParams.hashRateLog; break; case ZSTD_c_targetCBlockSize : *value = (int)CCtxParams->targetCBlockSize; @@ -971,8 +1075,11 @@ size_t ZSTD_CCtxParams_getParameter( case ZSTD_c_validateSequences : *value = (int)CCtxParams->validateSequences; break; - case ZSTD_c_useBlockSplitter : - *value = (int)CCtxParams->useBlockSplitter; + case ZSTD_c_splitAfterSequences : + *value = (int)CCtxParams->postBlockSplitter; + break; + case ZSTD_c_blockSplitterLevel : + *value = CCtxParams->preBlockSplitter_level; break; case ZSTD_c_useRowMatchFinder : *value = (int)CCtxParams->useRowMatchFinder; @@ -980,6 +1087,18 @@ size_t ZSTD_CCtxParams_getParameter( case ZSTD_c_deterministicRefPrefix: *value = (int)CCtxParams->deterministicRefPrefix; break; + case ZSTD_c_prefetchCDictTables: + *value = (int)CCtxParams->prefetchCDictTables; + break; + case ZSTD_c_enableSeqProducerFallback: + *value = CCtxParams->enableMatchFinderFallback; + break; + case ZSTD_c_maxBlockSize: + *value = (int)CCtxParams->maxBlockSize; + break; + case ZSTD_c_repcodeResolution: + *value = (int)CCtxParams->searchForExternalRepcodes; + break; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); } return 0; @@ -1006,9 +1125,47 @@ size_t ZSTD_CCtx_setParametersUsingCCtxParams( return 0; } +size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams) +{ + ZSTD_STATIC_ASSERT(sizeof(cparams) == 7 * 4 /* all params are listed below */); + DEBUGLOG(4, "ZSTD_CCtx_setCParams"); + /* only update if all parameters are valid */ + FORWARD_IF_ERROR(ZSTD_checkCParams(cparams), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, (int)cparams.windowLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_chainLog, (int)cparams.chainLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_hashLog, (int)cparams.hashLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_searchLog, (int)cparams.searchLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_minMatch, (int)cparams.minMatch), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_targetLength, (int)cparams.targetLength), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_strategy, (int)cparams.strategy), ""); + return 0; +} + +size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams) +{ + ZSTD_STATIC_ASSERT(sizeof(fparams) == 3 * 4 /* all params are listed below */); + DEBUGLOG(4, "ZSTD_CCtx_setFParams"); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_contentSizeFlag, fparams.contentSizeFlag != 0), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, fparams.checksumFlag != 0), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_dictIDFlag, fparams.noDictIDFlag == 0), ""); + return 0; +} + +size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params) +{ + DEBUGLOG(4, "ZSTD_CCtx_setParams"); + /* First check cParams, because we want to update all or none. */ + FORWARD_IF_ERROR(ZSTD_checkCParams(params.cParams), ""); + /* Next set fParams, because this could fail if the cctx isn't in init stage. */ + FORWARD_IF_ERROR(ZSTD_CCtx_setFParams(cctx, params.fParams), ""); + /* Finally set cParams, which should succeed. */ + FORWARD_IF_ERROR(ZSTD_CCtx_setCParams(cctx, params.cParams), ""); + return 0; +} + size_t ZSTD_CCtx_setPledgedSrcSize(ZSTD_CCtx* cctx, unsigned long long pledgedSrcSize) { - DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %u bytes", (U32)pledgedSrcSize); + DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %llu bytes", pledgedSrcSize); RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, "Can't set pledgedSrcSize when not in init stage."); cctx->pledgedSrcSizePlusOne = pledgedSrcSize+1; @@ -1024,9 +1181,9 @@ static void ZSTD_dedicatedDictSearch_revertCParams( ZSTD_compressionParameters* cParams); /* - * Initializes the local dict using the requested parameters. - * NOTE: This does not use the pledged src size, because it may be used for more - * than one compression. + * Initializes the local dictionary using requested parameters. + * NOTE: Initialization does not employ the pledged src size, + * because the dictionary may be used for multiple compressions. */ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) { @@ -1039,8 +1196,8 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) return 0; } if (dl->cdict != NULL) { - assert(cctx->cdict == dl->cdict); /* Local dictionary already initialized. */ + assert(cctx->cdict == dl->cdict); return 0; } assert(dl->dictSize > 0); @@ -1060,26 +1217,30 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) } size_t ZSTD_CCtx_loadDictionary_advanced( - ZSTD_CCtx* cctx, const void* dict, size_t dictSize, - ZSTD_dictLoadMethod_e dictLoadMethod, ZSTD_dictContentType_e dictContentType) + ZSTD_CCtx* cctx, + const void* dict, size_t dictSize, + ZSTD_dictLoadMethod_e dictLoadMethod, + ZSTD_dictContentType_e dictContentType) { - RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, - "Can't load a dictionary when ctx is not in init stage."); DEBUGLOG(4, "ZSTD_CCtx_loadDictionary_advanced (size: %u)", (U32)dictSize); - ZSTD_clearAllDicts(cctx); /* in case one already exists */ - if (dict == NULL || dictSize == 0) /* no dictionary mode */ + RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, + "Can't load a dictionary when cctx is not in init stage."); + ZSTD_clearAllDicts(cctx); /* erase any previously set dictionary */ + if (dict == NULL || dictSize == 0) /* no dictionary */ return 0; if (dictLoadMethod == ZSTD_dlm_byRef) { cctx->localDict.dict = dict; } else { + /* copy dictionary content inside CCtx to own its lifetime */ void* dictBuffer; RETURN_ERROR_IF(cctx->staticSize, memory_allocation, - "no malloc for static CCtx"); + "static CCtx can't allocate for an internal copy of dictionary"); dictBuffer = ZSTD_customMalloc(dictSize, cctx->customMem); - RETURN_ERROR_IF(!dictBuffer, memory_allocation, "NULL pointer!"); + RETURN_ERROR_IF(dictBuffer==NULL, memory_allocation, + "allocation failed for dictionary content"); ZSTD_memcpy(dictBuffer, dict, dictSize); - cctx->localDict.dictBuffer = dictBuffer; - cctx->localDict.dict = dictBuffer; + cctx->localDict.dictBuffer = dictBuffer; /* owned ptr to free */ + cctx->localDict.dict = dictBuffer; /* read-only reference */ } cctx->localDict.dictSize = dictSize; cctx->localDict.dictContentType = dictContentType; @@ -1149,7 +1310,7 @@ size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset) if ( (reset == ZSTD_reset_parameters) || (reset == ZSTD_reset_session_and_parameters) ) { RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, - "Can't reset parameters only when not in init stage."); + "Reset parameters is only possible during init stage."); ZSTD_clearAllDicts(cctx); return ZSTD_CCtxParams_reset(&cctx->requestedParams); } @@ -1168,7 +1329,7 @@ size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams) BOUNDCHECK(ZSTD_c_searchLog, (int)cParams.searchLog); BOUNDCHECK(ZSTD_c_minMatch, (int)cParams.minMatch); BOUNDCHECK(ZSTD_c_targetLength,(int)cParams.targetLength); - BOUNDCHECK(ZSTD_c_strategy, cParams.strategy); + BOUNDCHECK(ZSTD_c_strategy, (int)cParams.strategy); return 0; } @@ -1178,11 +1339,12 @@ size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams) static ZSTD_compressionParameters ZSTD_clampCParams(ZSTD_compressionParameters cParams) { -# define CLAMP_TYPE(cParam, val, type) { \ - ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ - if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ - } +# define CLAMP_TYPE(cParam, val, type) \ + do { \ + ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ + if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ + } while (0) # define CLAMP(cParam, val) CLAMP_TYPE(cParam, val, unsigned) CLAMP(ZSTD_c_windowLog, cParams.windowLog); CLAMP(ZSTD_c_chainLog, cParams.chainLog); @@ -1240,19 +1402,62 @@ static U32 ZSTD_dictAndWindowLog(U32 windowLog, U64 srcSize, U64 dictSize) * optimize `cPar` for a specified input (`srcSize` and `dictSize`). * mostly downsize to reduce memory consumption and initialization latency. * `srcSize` can be ZSTD_CONTENTSIZE_UNKNOWN when not known. - * `mode` is the mode for parameter adjustment. See docs for `ZSTD_cParamMode_e`. + * `mode` is the mode for parameter adjustment. See docs for `ZSTD_CParamMode_e`. * note : `srcSize==0` means 0! * condition : cPar is presumed validated (can be checked using ZSTD_checkCParams()). */ static ZSTD_compressionParameters ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize, - ZSTD_cParamMode_e mode) + ZSTD_CParamMode_e mode, + ZSTD_ParamSwitch_e useRowMatchFinder) { const U64 minSrcSize = 513; /* (1<<9) + 1 */ const U64 maxWindowResize = 1ULL << (ZSTD_WINDOWLOG_MAX-1); assert(ZSTD_checkCParams(cPar)==0); + /* Cascade the selected strategy down to the next-highest one built into + * this binary. */ +#ifdef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btultra2) { + cPar.strategy = ZSTD_btultra; + } + if (cPar.strategy == ZSTD_btultra) { + cPar.strategy = ZSTD_btopt; + } +#endif +#ifdef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btopt) { + cPar.strategy = ZSTD_btlazy2; + } +#endif +#ifdef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btlazy2) { + cPar.strategy = ZSTD_lazy2; + } +#endif +#ifdef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_lazy2) { + cPar.strategy = ZSTD_lazy; + } +#endif +#ifdef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_lazy) { + cPar.strategy = ZSTD_greedy; + } +#endif +#ifdef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_greedy) { + cPar.strategy = ZSTD_dfast; + } +#endif +#ifdef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_dfast) { + cPar.strategy = ZSTD_fast; + cPar.targetLength = 0; + } +#endif + switch (mode) { case ZSTD_cpm_unknown: case ZSTD_cpm_noAttachDict: @@ -1281,8 +1486,8 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, } /* resize windowLog if input is small enough, to use less memory */ - if ( (srcSize < maxWindowResize) - && (dictSize < maxWindowResize) ) { + if ( (srcSize <= maxWindowResize) + && (dictSize <= maxWindowResize) ) { U32 const tSize = (U32)(srcSize + dictSize); static U32 const hashSizeMin = 1 << ZSTD_HASHLOG_MIN; U32 const srcLog = (tSize < hashSizeMin) ? ZSTD_HASHLOG_MIN : @@ -1300,6 +1505,42 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, if (cPar.windowLog < ZSTD_WINDOWLOG_ABSOLUTEMIN) cPar.windowLog = ZSTD_WINDOWLOG_ABSOLUTEMIN; /* minimum wlog required for valid frame header */ + /* We can't use more than 32 bits of hash in total, so that means that we require: + * (hashLog + 8) <= 32 && (chainLog + 8) <= 32 + */ + if (mode == ZSTD_cpm_createCDict && ZSTD_CDictIndicesAreTagged(&cPar)) { + U32 const maxShortCacheHashLog = 32 - ZSTD_SHORT_CACHE_TAG_BITS; + if (cPar.hashLog > maxShortCacheHashLog) { + cPar.hashLog = maxShortCacheHashLog; + } + if (cPar.chainLog > maxShortCacheHashLog) { + cPar.chainLog = maxShortCacheHashLog; + } + } + + + /* At this point, we aren't 100% sure if we are using the row match finder. + * Unless it is explicitly disabled, conservatively assume that it is enabled. + * In this case it will only be disabled for small sources, so shrinking the + * hash log a little bit shouldn't result in any ratio loss. + */ + if (useRowMatchFinder == ZSTD_ps_auto) + useRowMatchFinder = ZSTD_ps_enable; + + /* We can't hash more than 32-bits in total. So that means that we require: + * (hashLog - rowLog + 8) <= 32 + */ + if (ZSTD_rowMatchFinderUsed(cPar.strategy, useRowMatchFinder)) { + /* Switch to 32-entry rows if searchLog is 5 (or more) */ + U32 const rowLog = BOUNDED(4, cPar.searchLog, 6); + U32 const maxRowHashLog = 32 - ZSTD_ROW_HASH_TAG_BITS; + U32 const maxHashLog = maxRowHashLog + rowLog; + assert(cPar.hashLog >= rowLog); + if (cPar.hashLog > maxHashLog) { + cPar.hashLog = maxHashLog; + } + } + return cPar; } @@ -1310,11 +1551,11 @@ ZSTD_adjustCParams(ZSTD_compressionParameters cPar, { cPar = ZSTD_clampCParams(cPar); /* resulting cPar is necessarily valid (all parameters within range) */ if (srcSize == 0) srcSize = ZSTD_CONTENTSIZE_UNKNOWN; - return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown); + return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown, ZSTD_ps_auto); } -static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode); -static ZSTD_parameters ZSTD_getParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode); +static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode); +static ZSTD_parameters ZSTD_getParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode); static void ZSTD_overrideCParams( ZSTD_compressionParameters* cParams, @@ -1330,24 +1571,25 @@ static void ZSTD_overrideCParams( } ZSTD_compressionParameters ZSTD_getCParamsFromCCtxParams( - const ZSTD_CCtx_params* CCtxParams, U64 srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode) + const ZSTD_CCtx_params* CCtxParams, U64 srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode) { ZSTD_compressionParameters cParams; if (srcSizeHint == ZSTD_CONTENTSIZE_UNKNOWN && CCtxParams->srcSizeHint > 0) { - srcSizeHint = CCtxParams->srcSizeHint; + assert(CCtxParams->srcSizeHint>=0); + srcSizeHint = (U64)CCtxParams->srcSizeHint; } cParams = ZSTD_getCParams_internal(CCtxParams->compressionLevel, srcSizeHint, dictSize, mode); if (CCtxParams->ldmParams.enableLdm == ZSTD_ps_enable) cParams.windowLog = ZSTD_LDM_DEFAULT_WINDOW_LOG; ZSTD_overrideCParams(&cParams, &CCtxParams->cParams); assert(!ZSTD_checkCParams(cParams)); /* srcSizeHint == 0 means 0 */ - return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode); + return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode, CCtxParams->useRowMatchFinder); } static size_t ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, - const ZSTD_paramSwitch_e useRowMatchFinder, - const U32 enableDedicatedDictSearch, + const ZSTD_ParamSwitch_e useRowMatchFinder, + const int enableDedicatedDictSearch, const U32 forCCtx) { /* chain table size should be 0 for fast or row-hash strategies */ @@ -1363,14 +1605,14 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, + hSize * sizeof(U32) + h3Size * sizeof(U32); size_t const optPotentialSpace = - ZSTD_cwksp_aligned_alloc_size((MaxML+1) * sizeof(U32)) - + ZSTD_cwksp_aligned_alloc_size((MaxLL+1) * sizeof(U32)) - + ZSTD_cwksp_aligned_alloc_size((MaxOff+1) * sizeof(U32)) - + ZSTD_cwksp_aligned_alloc_size((1<strategy, useRowMatchFinder) - ? ZSTD_cwksp_aligned_alloc_size(hSize*sizeof(U16)) + ? ZSTD_cwksp_aligned64_alloc_size(hSize) : 0; size_t const optSpace = (forCCtx && (cParams->strategy >= ZSTD_btopt)) ? optPotentialSpace @@ -1386,30 +1628,38 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, return tableSpace + optSpace + slackSpace + lazyAdditionalSpace; } +/* Helper function for calculating memory requirements. + * Gives a tighter bound than ZSTD_sequenceBound() by taking minMatch into account. */ +static size_t ZSTD_maxNbSeq(size_t blockSize, unsigned minMatch, int useSequenceProducer) { + U32 const divider = (minMatch==3 || useSequenceProducer) ? 3 : 4; + return blockSize / divider; +} + static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( const ZSTD_compressionParameters* cParams, const ldmParams_t* ldmParams, const int isStatic, - const ZSTD_paramSwitch_e useRowMatchFinder, + const ZSTD_ParamSwitch_e useRowMatchFinder, const size_t buffInSize, const size_t buffOutSize, - const U64 pledgedSrcSize) + const U64 pledgedSrcSize, + int useSequenceProducer, + size_t maxBlockSize) { size_t const windowSize = (size_t) BOUNDED(1ULL, 1ULL << cParams->windowLog, pledgedSrcSize); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); - U32 const divider = (cParams->minMatch==3) ? 3 : 4; - size_t const maxNbSeq = blockSize / divider; + size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(maxBlockSize), windowSize); + size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, cParams->minMatch, useSequenceProducer); size_t const tokenSpace = ZSTD_cwksp_alloc_size(WILDCOPY_OVERLENGTH + blockSize) - + ZSTD_cwksp_aligned_alloc_size(maxNbSeq * sizeof(seqDef)) + + ZSTD_cwksp_aligned64_alloc_size(maxNbSeq * sizeof(SeqDef)) + 3 * ZSTD_cwksp_alloc_size(maxNbSeq * sizeof(BYTE)); - size_t const entropySpace = ZSTD_cwksp_alloc_size(ENTROPY_WORKSPACE_SIZE); + size_t const tmpWorkSpace = ZSTD_cwksp_alloc_size(TMP_WORKSPACE_SIZE); size_t const blockStateSpace = 2 * ZSTD_cwksp_alloc_size(sizeof(ZSTD_compressedBlockState_t)); size_t const matchStateSize = ZSTD_sizeof_matchState(cParams, useRowMatchFinder, /* enableDedicatedDictSearch */ 0, /* forCCtx */ 1); size_t const ldmSpace = ZSTD_ldm_getTableSize(*ldmParams); size_t const maxNbLdmSeq = ZSTD_ldm_getMaxNbSeq(*ldmParams, blockSize); size_t const ldmSeqSpace = ldmParams->enableLdm == ZSTD_ps_enable ? - ZSTD_cwksp_aligned_alloc_size(maxNbLdmSeq * sizeof(rawSeq)) : 0; + ZSTD_cwksp_aligned64_alloc_size(maxNbLdmSeq * sizeof(rawSeq)) : 0; size_t const bufferSpace = ZSTD_cwksp_alloc_size(buffInSize) @@ -1417,15 +1667,21 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( size_t const cctxSpace = isStatic ? ZSTD_cwksp_alloc_size(sizeof(ZSTD_CCtx)) : 0; + size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); + size_t const externalSeqSpace = useSequenceProducer + ? ZSTD_cwksp_aligned64_alloc_size(maxNbExternalSeq * sizeof(ZSTD_Sequence)) + : 0; + size_t const neededSpace = cctxSpace + - entropySpace + + tmpWorkSpace + blockStateSpace + ldmSpace + ldmSeqSpace + matchStateSize + tokenSpace + - bufferSpace; + bufferSpace + + externalSeqSpace; DEBUGLOG(5, "estimate workspace : %u", (U32)neededSpace); return neededSpace; @@ -1435,7 +1691,7 @@ size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params) { ZSTD_compressionParameters const cParams = ZSTD_getCParamsFromCCtxParams(params, ZSTD_CONTENTSIZE_UNKNOWN, 0, ZSTD_cpm_noAttachDict); - ZSTD_paramSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params->useRowMatchFinder, + ZSTD_ParamSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params->useRowMatchFinder, &cParams); RETURN_ERROR_IF(params->nbWorkers > 0, GENERIC, "Estimate CCtx size is supported for single-threaded compression only."); @@ -1443,7 +1699,7 @@ size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params) * be needed. However, we still allocate two 0-sized buffers, which can * take space under ASAN. */ return ZSTD_estimateCCtxSize_usingCCtxParams_internal( - &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN); + &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); } size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams) @@ -1493,18 +1749,18 @@ size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params) RETURN_ERROR_IF(params->nbWorkers > 0, GENERIC, "Estimate CCtx size is supported for single-threaded compression only."); { ZSTD_compressionParameters const cParams = ZSTD_getCParamsFromCCtxParams(params, ZSTD_CONTENTSIZE_UNKNOWN, 0, ZSTD_cpm_noAttachDict); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, (size_t)1 << cParams.windowLog); + size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(params->maxBlockSize), (size_t)1 << cParams.windowLog); size_t const inBuffSize = (params->inBufferMode == ZSTD_bm_buffered) ? ((size_t)1 << cParams.windowLog) + blockSize : 0; size_t const outBuffSize = (params->outBufferMode == ZSTD_bm_buffered) ? ZSTD_compressBound(blockSize) + 1 : 0; - ZSTD_paramSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params->useRowMatchFinder, ¶ms->cParams); + ZSTD_ParamSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params->useRowMatchFinder, ¶ms->cParams); return ZSTD_estimateCCtxSize_usingCCtxParams_internal( &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, inBuffSize, outBuffSize, - ZSTD_CONTENTSIZE_UNKNOWN); + ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); } } @@ -1600,7 +1856,7 @@ void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs) * Invalidate all the matches in the match finder tables. * Requires nextSrc and base to be set (can be NULL). */ -static void ZSTD_invalidateMatchState(ZSTD_matchState_t* ms) +static void ZSTD_invalidateMatchState(ZSTD_MatchState_t* ms) { ZSTD_window_clear(&ms->window); @@ -1637,12 +1893,25 @@ typedef enum { ZSTD_resetTarget_CCtx } ZSTD_resetTarget_e; +/* Mixes bits in a 64 bits in a value, based on XXH3_rrmxmx */ +static U64 ZSTD_bitmix(U64 val, U64 len) { + val ^= ZSTD_rotateRight_U64(val, 49) ^ ZSTD_rotateRight_U64(val, 24); + val *= 0x9FB21C651E98DF25ULL; + val ^= (val >> 35) + len ; + val *= 0x9FB21C651E98DF25ULL; + return val ^ (val >> 28); +} + +/* Mixes in the hashSalt and hashSaltEntropy to create a new hashSalt */ +static void ZSTD_advanceHashSalt(ZSTD_MatchState_t* ms) { + ms->hashSalt = ZSTD_bitmix(ms->hashSalt, 8) ^ ZSTD_bitmix((U64) ms->hashSaltEntropy, 4); +} static size_t -ZSTD_reset_matchState(ZSTD_matchState_t* ms, +ZSTD_reset_matchState(ZSTD_MatchState_t* ms, ZSTD_cwksp* ws, const ZSTD_compressionParameters* cParams, - const ZSTD_paramSwitch_e useRowMatchFinder, + const ZSTD_ParamSwitch_e useRowMatchFinder, const ZSTD_compResetPolicy_e crp, const ZSTD_indexResetPolicy_e forceResetIndex, const ZSTD_resetTarget_e forWho) @@ -1664,6 +1933,7 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, } ms->hashLog3 = hashLog3; + ms->lazySkipping = 0; ZSTD_invalidateMatchState(ms); @@ -1685,22 +1955,19 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, ZSTD_cwksp_clean_tables(ws); } - /* opt parser space */ - if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { - DEBUGLOG(4, "reserving optimal parser space"); - ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxLL+1) * sizeof(unsigned)); - ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxML+1) * sizeof(unsigned)); - ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxOff+1) * sizeof(unsigned)); - ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * sizeof(ZSTD_match_t)); - ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * sizeof(ZSTD_optimal_t)); - } - if (ZSTD_rowMatchFinderUsed(cParams->strategy, useRowMatchFinder)) { - { /* Row match finder needs an additional table of hashes ("tags") */ - size_t const tagTableSize = hSize*sizeof(U16); - ms->tagTable = (U16*)ZSTD_cwksp_reserve_aligned(ws, tagTableSize); - if (ms->tagTable) ZSTD_memset(ms->tagTable, 0, tagTableSize); + /* Row match finder needs an additional table of hashes ("tags") */ + size_t const tagTableSize = hSize; + /* We want to generate a new salt in case we reset a Cctx, but we always want to use + * 0 when we reset a Cdict */ + if(forWho == ZSTD_resetTarget_CCtx) { + ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned_init_once(ws, tagTableSize); + ZSTD_advanceHashSalt(ms); + } else { + /* When we are not salting we want to always memset the memory */ + ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned64(ws, tagTableSize); + ZSTD_memset(ms->tagTable, 0, tagTableSize); + ms->hashSalt = 0; } { /* Switch to 32-entry rows if searchLog is 5 (or more) */ U32 const rowLog = BOUNDED(4, cParams->searchLog, 6); @@ -1709,6 +1976,17 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, } } + /* opt parser space */ + if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { + DEBUGLOG(4, "reserving optimal parser space"); + ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned64(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned64(ws, (MaxLL+1) * sizeof(unsigned)); + ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned64(ws, (MaxML+1) * sizeof(unsigned)); + ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned64(ws, (MaxOff+1) * sizeof(unsigned)); + ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned64(ws, ZSTD_OPT_SIZE * sizeof(ZSTD_match_t)); + ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned64(ws, ZSTD_OPT_SIZE * sizeof(ZSTD_optimal_t)); + } + ms->cParams = *cParams; RETURN_ERROR_IF(ZSTD_cwksp_reserve_failed(ws), memory_allocation, @@ -1754,7 +2032,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, { ZSTD_cwksp* const ws = &zc->workspace; DEBUGLOG(4, "ZSTD_resetCCtx_internal: pledgedSrcSize=%u, wlog=%u, useRowMatchFinder=%d useBlockSplitter=%d", - (U32)pledgedSrcSize, params->cParams.windowLog, (int)params->useRowMatchFinder, (int)params->useBlockSplitter); + (U32)pledgedSrcSize, params->cParams.windowLog, (int)params->useRowMatchFinder, (int)params->postBlockSplitter); assert(!ZSTD_isError(ZSTD_checkCParams(params->cParams))); zc->isFirstBlock = 1; @@ -1766,8 +2044,9 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, params = &zc->appliedParams; assert(params->useRowMatchFinder != ZSTD_ps_auto); - assert(params->useBlockSplitter != ZSTD_ps_auto); + assert(params->postBlockSplitter != ZSTD_ps_auto); assert(params->ldmParams.enableLdm != ZSTD_ps_auto); + assert(params->maxBlockSize != 0); if (params->ldmParams.enableLdm == ZSTD_ps_enable) { /* Adjust long distance matching parameters */ ZSTD_ldm_adjustParameters(&zc->appliedParams.ldmParams, ¶ms->cParams); @@ -1776,9 +2055,8 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, } { size_t const windowSize = MAX(1, (size_t)MIN(((U64)1 << params->cParams.windowLog), pledgedSrcSize)); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); - U32 const divider = (params->cParams.minMatch==3) ? 3 : 4; - size_t const maxNbSeq = blockSize / divider; + size_t const blockSize = MIN(params->maxBlockSize, windowSize); + size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, params->cParams.minMatch, ZSTD_hasExtSeqProd(params)); size_t const buffOutSize = (zbuff == ZSTDb_buffered && params->outBufferMode == ZSTD_bm_buffered) ? ZSTD_compressBound(blockSize) + 1 : 0; @@ -1795,8 +2073,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, size_t const neededSpace = ZSTD_estimateCCtxSize_usingCCtxParams_internal( ¶ms->cParams, ¶ms->ldmParams, zc->staticSize != 0, params->useRowMatchFinder, - buffInSize, buffOutSize, pledgedSrcSize); - int resizeWorkspace; + buffInSize, buffOutSize, pledgedSrcSize, ZSTD_hasExtSeqProd(params), params->maxBlockSize); FORWARD_IF_ERROR(neededSpace, "cctx size estimate failed!"); @@ -1805,7 +2082,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, { /* Check if workspace is large enough, alloc a new one if needed */ int const workspaceTooSmall = ZSTD_cwksp_sizeof(ws) < neededSpace; int const workspaceWasteful = ZSTD_cwksp_check_wasteful(ws, neededSpace); - resizeWorkspace = workspaceTooSmall || workspaceWasteful; + int resizeWorkspace = workspaceTooSmall || workspaceWasteful; DEBUGLOG(4, "Need %zu B workspace", neededSpace); DEBUGLOG(4, "windowSize: %zu - blockSize: %zu", windowSize, blockSize); @@ -1823,21 +2100,23 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, DEBUGLOG(5, "reserving object space"); /* Statically sized space. - * entropyWorkspace never moves, + * tmpWorkspace never moves, * though prev/next block swap places */ assert(ZSTD_cwksp_check_available(ws, 2 * sizeof(ZSTD_compressedBlockState_t))); zc->blockState.prevCBlock = (ZSTD_compressedBlockState_t*) ZSTD_cwksp_reserve_object(ws, sizeof(ZSTD_compressedBlockState_t)); RETURN_ERROR_IF(zc->blockState.prevCBlock == NULL, memory_allocation, "couldn't allocate prevCBlock"); zc->blockState.nextCBlock = (ZSTD_compressedBlockState_t*) ZSTD_cwksp_reserve_object(ws, sizeof(ZSTD_compressedBlockState_t)); RETURN_ERROR_IF(zc->blockState.nextCBlock == NULL, memory_allocation, "couldn't allocate nextCBlock"); - zc->entropyWorkspace = (U32*) ZSTD_cwksp_reserve_object(ws, ENTROPY_WORKSPACE_SIZE); - RETURN_ERROR_IF(zc->entropyWorkspace == NULL, memory_allocation, "couldn't allocate entropyWorkspace"); + zc->tmpWorkspace = ZSTD_cwksp_reserve_object(ws, TMP_WORKSPACE_SIZE); + RETURN_ERROR_IF(zc->tmpWorkspace == NULL, memory_allocation, "couldn't allocate tmpWorkspace"); + zc->tmpWkspSize = TMP_WORKSPACE_SIZE; } } ZSTD_cwksp_clear(ws); /* init params */ zc->blockState.matchState.cParams = params->cParams; + zc->blockState.matchState.prefetchCDictTables = params->prefetchCDictTables == ZSTD_ps_enable; zc->pledgedSrcSizePlusOne = pledgedSrcSize+1; zc->consumedSrcSize = 0; zc->producedCSize = 0; @@ -1845,7 +2124,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, zc->appliedParams.fParams.contentSizeFlag = 0; DEBUGLOG(4, "pledged content size : %u ; flag : %u", (unsigned)pledgedSrcSize, zc->appliedParams.fParams.contentSizeFlag); - zc->blockSize = blockSize; + zc->blockSizeMax = blockSize; xxh64_reset(&zc->xxhState, 0); zc->stage = ZSTDcs_init; @@ -1854,13 +2133,46 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, ZSTD_reset_compressedBlockState(zc->blockState.prevCBlock); + FORWARD_IF_ERROR(ZSTD_reset_matchState( + &zc->blockState.matchState, + ws, + ¶ms->cParams, + params->useRowMatchFinder, + crp, + needsIndexReset, + ZSTD_resetTarget_CCtx), ""); + + zc->seqStore.sequencesStart = (SeqDef*)ZSTD_cwksp_reserve_aligned64(ws, maxNbSeq * sizeof(SeqDef)); + + /* ldm hash table */ + if (params->ldmParams.enableLdm == ZSTD_ps_enable) { + /* TODO: avoid memset? */ + size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; + zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned64(ws, ldmHSize * sizeof(ldmEntry_t)); + ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); + zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned64(ws, maxNbLdmSeq * sizeof(rawSeq)); + zc->maxNbLdmSequences = maxNbLdmSeq; + + ZSTD_window_init(&zc->ldmState.window); + zc->ldmState.loadedDictEnd = 0; + } + + /* reserve space for block-level external sequences */ + if (ZSTD_hasExtSeqProd(params)) { + size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); + zc->extSeqBufCapacity = maxNbExternalSeq; + zc->extSeqBuf = + (ZSTD_Sequence*)ZSTD_cwksp_reserve_aligned64(ws, maxNbExternalSeq * sizeof(ZSTD_Sequence)); + } + + /* buffers */ + /* ZSTD_wildcopy() is used to copy into the literals buffer, * so we have to oversize the buffer by WILDCOPY_OVERLENGTH bytes. */ zc->seqStore.litStart = ZSTD_cwksp_reserve_buffer(ws, blockSize + WILDCOPY_OVERLENGTH); zc->seqStore.maxNbLit = blockSize; - /* buffers */ zc->bufferedPolicy = zbuff; zc->inBuffSize = buffInSize; zc->inBuff = (char*)ZSTD_cwksp_reserve_buffer(ws, buffInSize); @@ -1883,32 +2195,9 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, zc->seqStore.llCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); zc->seqStore.mlCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); zc->seqStore.ofCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); - zc->seqStore.sequencesStart = (seqDef*)ZSTD_cwksp_reserve_aligned(ws, maxNbSeq * sizeof(seqDef)); - - FORWARD_IF_ERROR(ZSTD_reset_matchState( - &zc->blockState.matchState, - ws, - ¶ms->cParams, - params->useRowMatchFinder, - crp, - needsIndexReset, - ZSTD_resetTarget_CCtx), ""); - - /* ldm hash table */ - if (params->ldmParams.enableLdm == ZSTD_ps_enable) { - /* TODO: avoid memset? */ - size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; - zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned(ws, ldmHSize * sizeof(ldmEntry_t)); - ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); - zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned(ws, maxNbLdmSeq * sizeof(rawSeq)); - zc->maxNbLdmSequences = maxNbLdmSeq; - - ZSTD_window_init(&zc->ldmState.window); - zc->ldmState.loadedDictEnd = 0; - } DEBUGLOG(3, "wksp: finished allocating, %zd bytes remain available", ZSTD_cwksp_available_space(ws)); - assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace, resizeWorkspace)); + assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace)); zc->initialized = 1; @@ -1980,7 +2269,8 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, } params.cParams = ZSTD_adjustCParams_internal(adjusted_cdict_cParams, pledgedSrcSize, - cdict->dictContentSize, ZSTD_cpm_attachDict); + cdict->dictContentSize, ZSTD_cpm_attachDict, + params.useRowMatchFinder); params.cParams.windowLog = windowLog; params.useRowMatchFinder = cdict->useRowMatchFinder; /* cdict overrides */ FORWARD_IF_ERROR(ZSTD_resetCCtx_internal(cctx, ¶ms, pledgedSrcSize, @@ -2019,6 +2309,22 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, return 0; } +static void ZSTD_copyCDictTableIntoCCtx(U32* dst, U32 const* src, size_t tableSize, + ZSTD_compressionParameters const* cParams) { + if (ZSTD_CDictIndicesAreTagged(cParams)){ + /* Remove tags from the CDict table if they are present. + * See docs on "short cache" in zstd_compress_internal.h for context. */ + size_t i; + for (i = 0; i < tableSize; i++) { + U32 const taggedIndex = src[i]; + U32 const index = taggedIndex >> ZSTD_SHORT_CACHE_TAG_BITS; + dst[i] = index; + } + } else { + ZSTD_memcpy(dst, src, tableSize * sizeof(U32)); + } +} + static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict, ZSTD_CCtx_params params, @@ -2054,26 +2360,29 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, : 0; size_t const hSize = (size_t)1 << cdict_cParams->hashLog; - ZSTD_memcpy(cctx->blockState.matchState.hashTable, - cdict->matchState.hashTable, - hSize * sizeof(U32)); + ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.hashTable, + cdict->matchState.hashTable, + hSize, cdict_cParams); + /* Do not copy cdict's chainTable if cctx has parameters such that it would not use chainTable */ if (ZSTD_allocateChainTable(cctx->appliedParams.cParams.strategy, cctx->appliedParams.useRowMatchFinder, 0 /* forDDSDict */)) { - ZSTD_memcpy(cctx->blockState.matchState.chainTable, - cdict->matchState.chainTable, - chainSize * sizeof(U32)); + ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.chainTable, + cdict->matchState.chainTable, + chainSize, cdict_cParams); } /* copy tag table */ if (ZSTD_rowMatchFinderUsed(cdict_cParams->strategy, cdict->useRowMatchFinder)) { - size_t const tagTableSize = hSize*sizeof(U16); + size_t const tagTableSize = hSize; ZSTD_memcpy(cctx->blockState.matchState.tagTable, - cdict->matchState.tagTable, - tagTableSize); + cdict->matchState.tagTable, + tagTableSize); + cctx->blockState.matchState.hashSalt = cdict->matchState.hashSalt; } } /* Zero the hashTable3, since the cdict never fills it */ - { int const h3log = cctx->blockState.matchState.hashLog3; + assert(cctx->blockState.matchState.hashLog3 <= 31); + { U32 const h3log = cctx->blockState.matchState.hashLog3; size_t const h3Size = h3log ? ((size_t)1 << h3log) : 0; assert(cdict->matchState.hashLog3 == 0); ZSTD_memset(cctx->blockState.matchState.hashTable3, 0, h3Size * sizeof(U32)); @@ -2082,8 +2391,8 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, ZSTD_cwksp_mark_tables_clean(&cctx->workspace); /* copy dictionary offsets */ - { ZSTD_matchState_t const* srcMatchState = &cdict->matchState; - ZSTD_matchState_t* dstMatchState = &cctx->blockState.matchState; + { ZSTD_MatchState_t const* srcMatchState = &cdict->matchState; + ZSTD_MatchState_t* dstMatchState = &cctx->blockState.matchState; dstMatchState->window = srcMatchState->window; dstMatchState->nextToUpdate = srcMatchState->nextToUpdate; dstMatchState->loadedDictEnd= srcMatchState->loadedDictEnd; @@ -2141,12 +2450,13 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx, /* Copy only compression parameters related to tables. */ params.cParams = srcCCtx->appliedParams.cParams; assert(srcCCtx->appliedParams.useRowMatchFinder != ZSTD_ps_auto); - assert(srcCCtx->appliedParams.useBlockSplitter != ZSTD_ps_auto); + assert(srcCCtx->appliedParams.postBlockSplitter != ZSTD_ps_auto); assert(srcCCtx->appliedParams.ldmParams.enableLdm != ZSTD_ps_auto); params.useRowMatchFinder = srcCCtx->appliedParams.useRowMatchFinder; - params.useBlockSplitter = srcCCtx->appliedParams.useBlockSplitter; + params.postBlockSplitter = srcCCtx->appliedParams.postBlockSplitter; params.ldmParams = srcCCtx->appliedParams.ldmParams; params.fParams = fParams; + params.maxBlockSize = srcCCtx->appliedParams.maxBlockSize; ZSTD_resetCCtx_internal(dstCCtx, ¶ms, pledgedSrcSize, /* loadedDictSize */ 0, ZSTDcrp_leaveDirty, zbuff); @@ -2166,7 +2476,7 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx, ? ((size_t)1 << srcCCtx->appliedParams.cParams.chainLog) : 0; size_t const hSize = (size_t)1 << srcCCtx->appliedParams.cParams.hashLog; - int const h3log = srcCCtx->blockState.matchState.hashLog3; + U32 const h3log = srcCCtx->blockState.matchState.hashLog3; size_t const h3Size = h3log ? ((size_t)1 << h3log) : 0; ZSTD_memcpy(dstCCtx->blockState.matchState.hashTable, @@ -2184,8 +2494,8 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx, /* copy dictionary offsets */ { - const ZSTD_matchState_t* srcMatchState = &srcCCtx->blockState.matchState; - ZSTD_matchState_t* dstMatchState = &dstCCtx->blockState.matchState; + const ZSTD_MatchState_t* srcMatchState = &srcCCtx->blockState.matchState; + ZSTD_MatchState_t* dstMatchState = &dstCCtx->blockState.matchState; dstMatchState->window = srcMatchState->window; dstMatchState->nextToUpdate = srcMatchState->nextToUpdate; dstMatchState->loadedDictEnd= srcMatchState->loadedDictEnd; @@ -2234,7 +2544,7 @@ ZSTD_reduceTable_internal (U32* const table, U32 const size, U32 const reducerVa /* Protect special index values < ZSTD_WINDOW_START_INDEX. */ U32 const reducerThreshold = reducerValue + ZSTD_WINDOW_START_INDEX; assert((size & (ZSTD_ROWSIZE-1)) == 0); /* multiple of ZSTD_ROWSIZE */ - assert(size < (1U<<31)); /* can be casted to int */ + assert(size < (1U<<31)); /* can be cast to int */ for (rowNb=0 ; rowNb < nbRows ; rowNb++) { @@ -2267,7 +2577,7 @@ static void ZSTD_reduceTable_btlazy2(U32* const table, U32 const size, U32 const /*! ZSTD_reduceIndex() : * rescale all indexes to avoid future overflow (indexes are U32) */ -static void ZSTD_reduceIndex (ZSTD_matchState_t* ms, ZSTD_CCtx_params const* params, const U32 reducerValue) +static void ZSTD_reduceIndex (ZSTD_MatchState_t* ms, ZSTD_CCtx_params const* params, const U32 reducerValue) { { U32 const hSize = (U32)1 << params->cParams.hashLog; ZSTD_reduceTable(ms->hashTable, hSize, reducerValue); @@ -2294,26 +2604,32 @@ static void ZSTD_reduceIndex (ZSTD_matchState_t* ms, ZSTD_CCtx_params const* par /* See doc/zstd_compression_format.md for detailed format description */ -void ZSTD_seqToCodes(const seqStore_t* seqStorePtr) +int ZSTD_seqToCodes(const SeqStore_t* seqStorePtr) { - const seqDef* const sequences = seqStorePtr->sequencesStart; + const SeqDef* const sequences = seqStorePtr->sequencesStart; BYTE* const llCodeTable = seqStorePtr->llCode; BYTE* const ofCodeTable = seqStorePtr->ofCode; BYTE* const mlCodeTable = seqStorePtr->mlCode; U32 const nbSeq = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); U32 u; + int longOffsets = 0; assert(nbSeq <= seqStorePtr->maxNbSeq); for (u=0; u= STREAM_ACCUMULATOR_MIN)); + if (MEM_32bits() && ofCode >= STREAM_ACCUMULATOR_MIN) + longOffsets = 1; } if (seqStorePtr->longLengthType==ZSTD_llt_literalLength) llCodeTable[seqStorePtr->longLengthPos] = MaxLL; if (seqStorePtr->longLengthType==ZSTD_llt_matchLength) mlCodeTable[seqStorePtr->longLengthPos] = MaxML; + return longOffsets; } /* ZSTD_useTargetCBlockSize(): @@ -2333,9 +2649,9 @@ static int ZSTD_useTargetCBlockSize(const ZSTD_CCtx_params* cctxParams) * Returns 1 if true, 0 otherwise. */ static int ZSTD_blockSplitterEnabled(ZSTD_CCtx_params* cctxParams) { - DEBUGLOG(5, "ZSTD_blockSplitterEnabled (useBlockSplitter=%d)", cctxParams->useBlockSplitter); - assert(cctxParams->useBlockSplitter != ZSTD_ps_auto); - return (cctxParams->useBlockSplitter == ZSTD_ps_enable); + DEBUGLOG(5, "ZSTD_blockSplitterEnabled (postBlockSplitter=%d)", cctxParams->postBlockSplitter); + assert(cctxParams->postBlockSplitter != ZSTD_ps_auto); + return (cctxParams->postBlockSplitter == ZSTD_ps_enable); } /* Type returned by ZSTD_buildSequencesStatistics containing finalized symbol encoding types @@ -2347,6 +2663,7 @@ typedef struct { U32 MLtype; size_t size; size_t lastCountSize; /* Accounts for bug in 1.3.4. More detail in ZSTD_entropyCompressSeqStore_internal() */ + int longOffsets; } ZSTD_symbolEncodingTypeStats_t; /* ZSTD_buildSequencesStatistics(): @@ -2357,11 +2674,13 @@ typedef struct { * entropyWkspSize must be of size at least ENTROPY_WORKSPACE_SIZE - (MaxSeq + 1)*sizeof(U32) */ static ZSTD_symbolEncodingTypeStats_t -ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, - const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, - BYTE* dst, const BYTE* const dstEnd, - ZSTD_strategy strategy, unsigned* countWorkspace, - void* entropyWorkspace, size_t entropyWkspSize) { +ZSTD_buildSequencesStatistics( + const SeqStore_t* seqStorePtr, size_t nbSeq, + const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, + BYTE* dst, const BYTE* const dstEnd, + ZSTD_strategy strategy, unsigned* countWorkspace, + void* entropyWorkspace, size_t entropyWkspSize) +{ BYTE* const ostart = dst; const BYTE* const oend = dstEnd; BYTE* op = ostart; @@ -2375,7 +2694,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, stats.lastCountSize = 0; /* convert length/distances into codes */ - ZSTD_seqToCodes(seqStorePtr); + stats.longOffsets = ZSTD_seqToCodes(seqStorePtr); assert(op <= oend); assert(nbSeq != 0); /* ZSTD_selectEncodingType() divides by nbSeq */ /* build CTable for Literal Lengths */ @@ -2392,7 +2711,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, assert(!(stats.LLtype < set_compressed && nextEntropy->litlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */ { size_t const countSize = ZSTD_buildCTable( op, (size_t)(oend - op), - CTable_LitLength, LLFSELog, (symbolEncodingType_e)stats.LLtype, + CTable_LitLength, LLFSELog, (SymbolEncodingType_e)stats.LLtype, countWorkspace, max, llCodeTable, nbSeq, LL_defaultNorm, LL_defaultNormLog, MaxLL, prevEntropy->litlengthCTable, @@ -2413,7 +2732,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, size_t const mostFrequent = HIST_countFast_wksp( countWorkspace, &max, ofCodeTable, nbSeq, entropyWorkspace, entropyWkspSize); /* can't fail */ /* We can only use the basic table if max <= DefaultMaxOff, otherwise the offsets are too large */ - ZSTD_defaultPolicy_e const defaultPolicy = (max <= DefaultMaxOff) ? ZSTD_defaultAllowed : ZSTD_defaultDisallowed; + ZSTD_DefaultPolicy_e const defaultPolicy = (max <= DefaultMaxOff) ? ZSTD_defaultAllowed : ZSTD_defaultDisallowed; DEBUGLOG(5, "Building OF table"); nextEntropy->offcode_repeatMode = prevEntropy->offcode_repeatMode; stats.Offtype = ZSTD_selectEncodingType(&nextEntropy->offcode_repeatMode, @@ -2424,7 +2743,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, assert(!(stats.Offtype < set_compressed && nextEntropy->offcode_repeatMode != FSE_repeat_none)); /* We don't copy tables */ { size_t const countSize = ZSTD_buildCTable( op, (size_t)(oend - op), - CTable_OffsetBits, OffFSELog, (symbolEncodingType_e)stats.Offtype, + CTable_OffsetBits, OffFSELog, (SymbolEncodingType_e)stats.Offtype, countWorkspace, max, ofCodeTable, nbSeq, OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, prevEntropy->offcodeCTable, @@ -2454,7 +2773,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, assert(!(stats.MLtype < set_compressed && nextEntropy->matchlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */ { size_t const countSize = ZSTD_buildCTable( op, (size_t)(oend - op), - CTable_MatchLength, MLFSELog, (symbolEncodingType_e)stats.MLtype, + CTable_MatchLength, MLFSELog, (SymbolEncodingType_e)stats.MLtype, countWorkspace, max, mlCodeTable, nbSeq, ML_defaultNorm, ML_defaultNormLog, MaxML, prevEntropy->matchlengthCTable, @@ -2480,22 +2799,23 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, */ #define SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO 20 MEM_STATIC size_t -ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - void* dst, size_t dstCapacity, - void* entropyWorkspace, size_t entropyWkspSize, - const int bmi2) +ZSTD_entropyCompressSeqStore_internal( + void* dst, size_t dstCapacity, + const void* literals, size_t litSize, + const SeqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + void* entropyWorkspace, size_t entropyWkspSize, + const int bmi2) { - const int longOffsets = cctxParams->cParams.windowLog > STREAM_ACCUMULATOR_MIN; ZSTD_strategy const strategy = cctxParams->cParams.strategy; unsigned* count = (unsigned*)entropyWorkspace; FSE_CTable* CTable_LitLength = nextEntropy->fse.litlengthCTable; FSE_CTable* CTable_OffsetBits = nextEntropy->fse.offcodeCTable; FSE_CTable* CTable_MatchLength = nextEntropy->fse.matchlengthCTable; - const seqDef* const sequences = seqStorePtr->sequencesStart; - const size_t nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; + const SeqDef* const sequences = seqStorePtr->sequencesStart; + const size_t nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); const BYTE* const ofCodeTable = seqStorePtr->ofCode; const BYTE* const llCodeTable = seqStorePtr->llCode; const BYTE* const mlCodeTable = seqStorePtr->mlCode; @@ -2503,29 +2823,28 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, BYTE* const oend = ostart + dstCapacity; BYTE* op = ostart; size_t lastCountSize; + int longOffsets = 0; entropyWorkspace = count + (MaxSeq + 1); entropyWkspSize -= (MaxSeq + 1) * sizeof(*count); - DEBUGLOG(4, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu)", nbSeq); + DEBUGLOG(5, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu, dstCapacity=%zu)", nbSeq, dstCapacity); ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<= HUF_WORKSPACE_SIZE); /* Compress literals */ - { const BYTE* const literals = seqStorePtr->litStart; - size_t const numSequences = seqStorePtr->sequences - seqStorePtr->sequencesStart; - size_t const numLiterals = seqStorePtr->lit - seqStorePtr->litStart; + { size_t const numSequences = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); /* Base suspicion of uncompressibility on ratio of literals to sequences */ - unsigned const suspectUncompressible = (numSequences == 0) || (numLiterals / numSequences >= SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO); - size_t const litSize = (size_t)(seqStorePtr->lit - literals); + int const suspectUncompressible = (numSequences == 0) || (litSize / numSequences >= SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO); + size_t const cSize = ZSTD_compressLiterals( - &prevEntropy->huf, &nextEntropy->huf, - cctxParams->cParams.strategy, - ZSTD_literalsCompressionIsDisabled(cctxParams), op, dstCapacity, literals, litSize, entropyWorkspace, entropyWkspSize, - bmi2, suspectUncompressible); + &prevEntropy->huf, &nextEntropy->huf, + cctxParams->cParams.strategy, + ZSTD_literalsCompressionIsDisabled(cctxParams), + suspectUncompressible, bmi2); FORWARD_IF_ERROR(cSize, "ZSTD_compressLiterals failed"); assert(cSize <= dstCapacity); op += cSize; @@ -2551,11 +2870,10 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, ZSTD_memcpy(&nextEntropy->fse, &prevEntropy->fse, sizeof(prevEntropy->fse)); return (size_t)(op - ostart); } - { - ZSTD_symbolEncodingTypeStats_t stats; - BYTE* seqHead = op++; + { BYTE* const seqHead = op++; /* build stats for sequences */ - stats = ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, + const ZSTD_symbolEncodingTypeStats_t stats = + ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, &prevEntropy->fse, &nextEntropy->fse, op, oend, strategy, count, @@ -2564,6 +2882,7 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, *seqHead = (BYTE)((stats.LLtype<<6) + (stats.Offtype<<4) + (stats.MLtype<<2)); lastCountSize = stats.lastCountSize; op += stats.size; + longOffsets = stats.longOffsets; } { size_t const bitstreamSize = ZSTD_encodeSequences( @@ -2597,104 +2916,146 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, return (size_t)(op - ostart); } -MEM_STATIC size_t -ZSTD_entropyCompressSeqStore(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - void* dst, size_t dstCapacity, - size_t srcSize, - void* entropyWorkspace, size_t entropyWkspSize, - int bmi2) +static size_t +ZSTD_entropyCompressSeqStore_wExtLitBuffer( + void* dst, size_t dstCapacity, + const void* literals, size_t litSize, + size_t blockSize, + const SeqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + void* entropyWorkspace, size_t entropyWkspSize, + int bmi2) { size_t const cSize = ZSTD_entropyCompressSeqStore_internal( - seqStorePtr, prevEntropy, nextEntropy, cctxParams, dst, dstCapacity, + literals, litSize, + seqStorePtr, prevEntropy, nextEntropy, cctxParams, entropyWorkspace, entropyWkspSize, bmi2); if (cSize == 0) return 0; /* When srcSize <= dstCapacity, there is enough space to write a raw uncompressed block. * Since we ran out of space, block must be not compressible, so fall back to raw uncompressed block. */ - if ((cSize == ERROR(dstSize_tooSmall)) & (srcSize <= dstCapacity)) + if ((cSize == ERROR(dstSize_tooSmall)) & (blockSize <= dstCapacity)) { + DEBUGLOG(4, "not enough dstCapacity (%zu) for ZSTD_entropyCompressSeqStore_internal()=> do not compress block", dstCapacity); return 0; /* block not compressed */ + } FORWARD_IF_ERROR(cSize, "ZSTD_entropyCompressSeqStore_internal failed"); /* Check compressibility */ - { size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, cctxParams->cParams.strategy); + { size_t const maxCSize = blockSize - ZSTD_minGain(blockSize, cctxParams->cParams.strategy); if (cSize >= maxCSize) return 0; /* block not compressed */ } - DEBUGLOG(4, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); + DEBUGLOG(5, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); + /* libzstd decoder before > v1.5.4 is not compatible with compressed blocks of size ZSTD_BLOCKSIZE_MAX exactly. + * This restriction is indirectly already fulfilled by respecting ZSTD_minGain() condition above. + */ + assert(cSize < ZSTD_BLOCKSIZE_MAX); return cSize; } +static size_t +ZSTD_entropyCompressSeqStore( + const SeqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + void* dst, size_t dstCapacity, + size_t srcSize, + void* entropyWorkspace, size_t entropyWkspSize, + int bmi2) +{ + return ZSTD_entropyCompressSeqStore_wExtLitBuffer( + dst, dstCapacity, + seqStorePtr->litStart, (size_t)(seqStorePtr->lit - seqStorePtr->litStart), + srcSize, + seqStorePtr, + prevEntropy, nextEntropy, + cctxParams, + entropyWorkspace, entropyWkspSize, + bmi2); +} + /* ZSTD_selectBlockCompressor() : * Not static, but internal use only (used by long distance matcher) * assumption : strat is a valid strategy */ -ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramSwitch_e useRowMatchFinder, ZSTD_dictMode_e dictMode) +ZSTD_BlockCompressor_f ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_ParamSwitch_e useRowMatchFinder, ZSTD_dictMode_e dictMode) { - static const ZSTD_blockCompressor blockCompressor[4][ZSTD_STRATEGY_MAX+1] = { + static const ZSTD_BlockCompressor_f blockCompressor[4][ZSTD_STRATEGY_MAX+1] = { { ZSTD_compressBlock_fast /* default for 0 */, ZSTD_compressBlock_fast, - ZSTD_compressBlock_doubleFast, - ZSTD_compressBlock_greedy, - ZSTD_compressBlock_lazy, - ZSTD_compressBlock_lazy2, - ZSTD_compressBlock_btlazy2, - ZSTD_compressBlock_btopt, - ZSTD_compressBlock_btultra, - ZSTD_compressBlock_btultra2 }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST, + ZSTD_COMPRESSBLOCK_GREEDY, + ZSTD_COMPRESSBLOCK_LAZY, + ZSTD_COMPRESSBLOCK_LAZY2, + ZSTD_COMPRESSBLOCK_BTLAZY2, + ZSTD_COMPRESSBLOCK_BTOPT, + ZSTD_COMPRESSBLOCK_BTULTRA, + ZSTD_COMPRESSBLOCK_BTULTRA2 + }, { ZSTD_compressBlock_fast_extDict /* default for 0 */, ZSTD_compressBlock_fast_extDict, - ZSTD_compressBlock_doubleFast_extDict, - ZSTD_compressBlock_greedy_extDict, - ZSTD_compressBlock_lazy_extDict, - ZSTD_compressBlock_lazy2_extDict, - ZSTD_compressBlock_btlazy2_extDict, - ZSTD_compressBlock_btopt_extDict, - ZSTD_compressBlock_btultra_extDict, - ZSTD_compressBlock_btultra_extDict }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT, + ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT, + ZSTD_COMPRESSBLOCK_LAZY_EXTDICT, + ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT, + ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT, + ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT, + ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT, + ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT + }, { ZSTD_compressBlock_fast_dictMatchState /* default for 0 */, ZSTD_compressBlock_fast_dictMatchState, - ZSTD_compressBlock_doubleFast_dictMatchState, - ZSTD_compressBlock_greedy_dictMatchState, - ZSTD_compressBlock_lazy_dictMatchState, - ZSTD_compressBlock_lazy2_dictMatchState, - ZSTD_compressBlock_btlazy2_dictMatchState, - ZSTD_compressBlock_btopt_dictMatchState, - ZSTD_compressBlock_btultra_dictMatchState, - ZSTD_compressBlock_btultra_dictMatchState }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE + }, { NULL /* default for 0 */, NULL, NULL, - ZSTD_compressBlock_greedy_dedicatedDictSearch, - ZSTD_compressBlock_lazy_dedicatedDictSearch, - ZSTD_compressBlock_lazy2_dedicatedDictSearch, + ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH, + ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH, + ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH, NULL, NULL, NULL, NULL } }; - ZSTD_blockCompressor selectedCompressor; + ZSTD_BlockCompressor_f selectedCompressor; ZSTD_STATIC_ASSERT((unsigned)ZSTD_fast == 1); - assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, strat)); - DEBUGLOG(4, "Selected block compressor: dictMode=%d strat=%d rowMatchfinder=%d", (int)dictMode, (int)strat, (int)useRowMatchFinder); + assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, (int)strat)); + DEBUGLOG(5, "Selected block compressor: dictMode=%d strat=%d rowMatchfinder=%d", (int)dictMode, (int)strat, (int)useRowMatchFinder); if (ZSTD_rowMatchFinderUsed(strat, useRowMatchFinder)) { - static const ZSTD_blockCompressor rowBasedBlockCompressors[4][3] = { - { ZSTD_compressBlock_greedy_row, - ZSTD_compressBlock_lazy_row, - ZSTD_compressBlock_lazy2_row }, - { ZSTD_compressBlock_greedy_extDict_row, - ZSTD_compressBlock_lazy_extDict_row, - ZSTD_compressBlock_lazy2_extDict_row }, - { ZSTD_compressBlock_greedy_dictMatchState_row, - ZSTD_compressBlock_lazy_dictMatchState_row, - ZSTD_compressBlock_lazy2_dictMatchState_row }, - { ZSTD_compressBlock_greedy_dedicatedDictSearch_row, - ZSTD_compressBlock_lazy_dedicatedDictSearch_row, - ZSTD_compressBlock_lazy2_dedicatedDictSearch_row } + static const ZSTD_BlockCompressor_f rowBasedBlockCompressors[4][3] = { + { + ZSTD_COMPRESSBLOCK_GREEDY_ROW, + ZSTD_COMPRESSBLOCK_LAZY_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW, + ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW, + ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW, + ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW + } }; - DEBUGLOG(4, "Selecting a row-based matchfinder"); + DEBUGLOG(5, "Selecting a row-based matchfinder"); assert(useRowMatchFinder != ZSTD_ps_auto); selectedCompressor = rowBasedBlockCompressors[(int)dictMode][(int)strat - (int)ZSTD_greedy]; } else { @@ -2704,30 +3065,126 @@ ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramS return selectedCompressor; } -static void ZSTD_storeLastLiterals(seqStore_t* seqStorePtr, +static void ZSTD_storeLastLiterals(SeqStore_t* seqStorePtr, const BYTE* anchor, size_t lastLLSize) { ZSTD_memcpy(seqStorePtr->lit, anchor, lastLLSize); seqStorePtr->lit += lastLLSize; } -void ZSTD_resetSeqStore(seqStore_t* ssPtr) +void ZSTD_resetSeqStore(SeqStore_t* ssPtr) { ssPtr->lit = ssPtr->litStart; ssPtr->sequences = ssPtr->sequencesStart; ssPtr->longLengthType = ZSTD_llt_none; } -typedef enum { ZSTDbss_compress, ZSTDbss_noCompress } ZSTD_buildSeqStore_e; +/* ZSTD_postProcessSequenceProducerResult() : + * Validates and post-processes sequences obtained through the external matchfinder API: + * - Checks whether nbExternalSeqs represents an error condition. + * - Appends a block delimiter to outSeqs if one is not already present. + * See zstd.h for context regarding block delimiters. + * Returns the number of sequences after post-processing, or an error code. */ +static size_t ZSTD_postProcessSequenceProducerResult( + ZSTD_Sequence* outSeqs, size_t nbExternalSeqs, size_t outSeqsCapacity, size_t srcSize +) { + RETURN_ERROR_IF( + nbExternalSeqs > outSeqsCapacity, + sequenceProducer_failed, + "External sequence producer returned error code %lu", + (unsigned long)nbExternalSeqs + ); + + RETURN_ERROR_IF( + nbExternalSeqs == 0 && srcSize > 0, + sequenceProducer_failed, + "Got zero sequences from external sequence producer for a non-empty src buffer!" + ); + + if (srcSize == 0) { + ZSTD_memset(&outSeqs[0], 0, sizeof(ZSTD_Sequence)); + return 1; + } + + { + ZSTD_Sequence const lastSeq = outSeqs[nbExternalSeqs - 1]; + + /* We can return early if lastSeq is already a block delimiter. */ + if (lastSeq.offset == 0 && lastSeq.matchLength == 0) { + return nbExternalSeqs; + } + + /* This error condition is only possible if the external matchfinder + * produced an invalid parse, by definition of ZSTD_sequenceBound(). */ + RETURN_ERROR_IF( + nbExternalSeqs == outSeqsCapacity, + sequenceProducer_failed, + "nbExternalSeqs == outSeqsCapacity but lastSeq is not a block delimiter!" + ); + + /* lastSeq is not a block delimiter, so we need to append one. */ + ZSTD_memset(&outSeqs[nbExternalSeqs], 0, sizeof(ZSTD_Sequence)); + return nbExternalSeqs + 1; + } +} + +/* ZSTD_fastSequenceLengthSum() : + * Returns sum(litLen) + sum(matchLen) + lastLits for *seqBuf*. + * Similar to another function in zstd_compress.c (determine_blockSize), + * except it doesn't check for a block delimiter to end summation. + * Removing the early exit allows the compiler to auto-vectorize (https://godbolt.org/z/cY1cajz9P). + * This function can be deleted and replaced by determine_blockSize after we resolve issue #3456. */ +static size_t ZSTD_fastSequenceLengthSum(ZSTD_Sequence const* seqBuf, size_t seqBufSize) { + size_t matchLenSum, litLenSum, i; + matchLenSum = 0; + litLenSum = 0; + for (i = 0; i < seqBufSize; i++) { + litLenSum += seqBuf[i].litLength; + matchLenSum += seqBuf[i].matchLength; + } + return litLenSum + matchLenSum; +} + +/* + * Function to validate sequences produced by a block compressor. + */ +static void ZSTD_validateSeqStore(const SeqStore_t* seqStore, const ZSTD_compressionParameters* cParams) +{ +#if DEBUGLEVEL >= 1 + const SeqDef* seq = seqStore->sequencesStart; + const SeqDef* const seqEnd = seqStore->sequences; + size_t const matchLenLowerBound = cParams->minMatch == 3 ? 3 : 4; + for (; seq < seqEnd; ++seq) { + const ZSTD_SequenceLength seqLength = ZSTD_getSequenceLength(seqStore, seq); + assert(seqLength.matchLength >= matchLenLowerBound); + (void)seqLength; + (void)matchLenLowerBound; + } +#else + (void)seqStore; + (void)cParams; +#endif +} + +static size_t +ZSTD_transferSequences_wBlockDelim(ZSTD_CCtx* cctx, + ZSTD_SequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, + ZSTD_ParamSwitch_e externalRepSearch); + +typedef enum { ZSTDbss_compress, ZSTDbss_noCompress } ZSTD_BuildSeqStore_e; static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) { - ZSTD_matchState_t* const ms = &zc->blockState.matchState; + ZSTD_MatchState_t* const ms = &zc->blockState.matchState; DEBUGLOG(5, "ZSTD_buildSeqStore (srcSize=%zu)", srcSize); assert(srcSize <= ZSTD_BLOCKSIZE_MAX); /* Assert that we have correctly flushed the ctx params into the ms's copy */ ZSTD_assertEqualCParams(zc->appliedParams.cParams, ms->cParams); - if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { if (zc->appliedParams.cParams.strategy >= ZSTD_btopt) { ZSTD_ldm_skipRawSeqStoreBytes(&zc->externSeqStore, srcSize); } else { @@ -2763,6 +3220,15 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) } if (zc->externSeqStore.pos < zc->externSeqStore.size) { assert(zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_disable); + + /* External matchfinder + LDM is technically possible, just not implemented yet. + * We need to revisit soon and implement it. */ + RETURN_ERROR_IF( + ZSTD_hasExtSeqProd(&zc->appliedParams), + parameter_combination_unsupported, + "Long-distance matching with external sequence producer enabled is not currently supported." + ); + /* Updates ldmSeqStore.pos */ lastLLSize = ZSTD_ldm_blockCompress(&zc->externSeqStore, @@ -2772,7 +3238,15 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) src, srcSize); assert(zc->externSeqStore.pos <= zc->externSeqStore.size); } else if (zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable) { - rawSeqStore_t ldmSeqStore = kNullRawSeqStore; + RawSeqStore_t ldmSeqStore = kNullRawSeqStore; + + /* External matchfinder + LDM is technically possible, just not implemented yet. + * We need to revisit soon and implement it. */ + RETURN_ERROR_IF( + ZSTD_hasExtSeqProd(&zc->appliedParams), + parameter_combination_unsupported, + "Long-distance matching with external sequence producer enabled is not currently supported." + ); ldmSeqStore.seq = zc->ldmSequences; ldmSeqStore.capacity = zc->maxNbLdmSequences; @@ -2788,42 +3262,116 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) zc->appliedParams.useRowMatchFinder, src, srcSize); assert(ldmSeqStore.pos == ldmSeqStore.size); - } else { /* not long range mode */ - ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor(zc->appliedParams.cParams.strategy, - zc->appliedParams.useRowMatchFinder, - dictMode); + } else if (ZSTD_hasExtSeqProd(&zc->appliedParams)) { + assert( + zc->extSeqBufCapacity >= ZSTD_sequenceBound(srcSize) + ); + assert(zc->appliedParams.extSeqProdFunc != NULL); + + { U32 const windowSize = (U32)1 << zc->appliedParams.cParams.windowLog; + + size_t const nbExternalSeqs = (zc->appliedParams.extSeqProdFunc)( + zc->appliedParams.extSeqProdState, + zc->extSeqBuf, + zc->extSeqBufCapacity, + src, srcSize, + NULL, 0, /* dict and dictSize, currently not supported */ + zc->appliedParams.compressionLevel, + windowSize + ); + + size_t const nbPostProcessedSeqs = ZSTD_postProcessSequenceProducerResult( + zc->extSeqBuf, + nbExternalSeqs, + zc->extSeqBufCapacity, + srcSize + ); + + /* Return early if there is no error, since we don't need to worry about last literals */ + if (!ZSTD_isError(nbPostProcessedSeqs)) { + ZSTD_SequencePosition seqPos = {0,0,0}; + size_t const seqLenSum = ZSTD_fastSequenceLengthSum(zc->extSeqBuf, nbPostProcessedSeqs); + RETURN_ERROR_IF(seqLenSum > srcSize, externalSequences_invalid, "External sequences imply too large a block!"); + FORWARD_IF_ERROR( + ZSTD_transferSequences_wBlockDelim( + zc, &seqPos, + zc->extSeqBuf, nbPostProcessedSeqs, + src, srcSize, + zc->appliedParams.searchForExternalRepcodes + ), + "Failed to copy external sequences to seqStore!" + ); + ms->ldmSeqStore = NULL; + DEBUGLOG(5, "Copied %lu sequences from external sequence producer to internal seqStore.", (unsigned long)nbExternalSeqs); + return ZSTDbss_compress; + } + + /* Propagate the error if fallback is disabled */ + if (!zc->appliedParams.enableMatchFinderFallback) { + return nbPostProcessedSeqs; + } + + /* Fallback to software matchfinder */ + { ZSTD_BlockCompressor_f const blockCompressor = + ZSTD_selectBlockCompressor( + zc->appliedParams.cParams.strategy, + zc->appliedParams.useRowMatchFinder, + dictMode); + ms->ldmSeqStore = NULL; + DEBUGLOG( + 5, + "External sequence producer returned error code %lu. Falling back to internal parser.", + (unsigned long)nbExternalSeqs + ); + lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); + } } + } else { /* not long range mode and no external matchfinder */ + ZSTD_BlockCompressor_f const blockCompressor = ZSTD_selectBlockCompressor( + zc->appliedParams.cParams.strategy, + zc->appliedParams.useRowMatchFinder, + dictMode); ms->ldmSeqStore = NULL; lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); } { const BYTE* const lastLiterals = (const BYTE*)src + srcSize - lastLLSize; ZSTD_storeLastLiterals(&zc->seqStore, lastLiterals, lastLLSize); } } + ZSTD_validateSeqStore(&zc->seqStore, &zc->appliedParams.cParams); return ZSTDbss_compress; } -static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) +static size_t ZSTD_copyBlockSequences(SeqCollector* seqCollector, const SeqStore_t* seqStore, const U32 prevRepcodes[ZSTD_REP_NUM]) { - const seqStore_t* seqStore = ZSTD_getSeqStore(zc); - const seqDef* seqStoreSeqs = seqStore->sequencesStart; - size_t seqStoreSeqSize = seqStore->sequences - seqStoreSeqs; - size_t seqStoreLiteralsSize = (size_t)(seqStore->lit - seqStore->litStart); - size_t literalsRead = 0; - size_t lastLLSize; + const SeqDef* inSeqs = seqStore->sequencesStart; + const size_t nbInSequences = (size_t)(seqStore->sequences - inSeqs); + const size_t nbInLiterals = (size_t)(seqStore->lit - seqStore->litStart); - ZSTD_Sequence* outSeqs = &zc->seqCollector.seqStart[zc->seqCollector.seqIndex]; + ZSTD_Sequence* outSeqs = seqCollector->seqIndex == 0 ? seqCollector->seqStart : seqCollector->seqStart + seqCollector->seqIndex; + const size_t nbOutSequences = nbInSequences + 1; + size_t nbOutLiterals = 0; + Repcodes_t repcodes; size_t i; - repcodes_t updatedRepcodes; - - assert(zc->seqCollector.seqIndex + 1 < zc->seqCollector.maxSequences); - /* Ensure we have enough space for last literals "sequence" */ - assert(zc->seqCollector.maxSequences >= seqStoreSeqSize + 1); - ZSTD_memcpy(updatedRepcodes.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); - for (i = 0; i < seqStoreSeqSize; ++i) { - U32 rawOffset = seqStoreSeqs[i].offBase - ZSTD_REP_NUM; - outSeqs[i].litLength = seqStoreSeqs[i].litLength; - outSeqs[i].matchLength = seqStoreSeqs[i].mlBase + MINMATCH; + + /* Bounds check that we have enough space for every input sequence + * and the block delimiter + */ + assert(seqCollector->seqIndex <= seqCollector->maxSequences); + RETURN_ERROR_IF( + nbOutSequences > (size_t)(seqCollector->maxSequences - seqCollector->seqIndex), + dstSize_tooSmall, + "Not enough space to copy sequences"); + + ZSTD_memcpy(&repcodes, prevRepcodes, sizeof(repcodes)); + for (i = 0; i < nbInSequences; ++i) { + U32 rawOffset; + outSeqs[i].litLength = inSeqs[i].litLength; + outSeqs[i].matchLength = inSeqs[i].mlBase + MINMATCH; outSeqs[i].rep = 0; + /* Handle the possible single length >= 64K + * There can only be one because we add MINMATCH to every match length, + * and blocks are at most 128K. + */ if (i == seqStore->longLengthPos) { if (seqStore->longLengthType == ZSTD_llt_literalLength) { outSeqs[i].litLength += 0x10000; @@ -2832,46 +3380,75 @@ static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) } } - if (seqStoreSeqs[i].offBase <= ZSTD_REP_NUM) { - /* Derive the correct offset corresponding to a repcode */ - outSeqs[i].rep = seqStoreSeqs[i].offBase; + /* Determine the raw offset given the offBase, which may be a repcode. */ + if (OFFBASE_IS_REPCODE(inSeqs[i].offBase)) { + const U32 repcode = OFFBASE_TO_REPCODE(inSeqs[i].offBase); + assert(repcode > 0); + outSeqs[i].rep = repcode; if (outSeqs[i].litLength != 0) { - rawOffset = updatedRepcodes.rep[outSeqs[i].rep - 1]; + rawOffset = repcodes.rep[repcode - 1]; } else { - if (outSeqs[i].rep == 3) { - rawOffset = updatedRepcodes.rep[0] - 1; + if (repcode == 3) { + assert(repcodes.rep[0] > 1); + rawOffset = repcodes.rep[0] - 1; } else { - rawOffset = updatedRepcodes.rep[outSeqs[i].rep]; + rawOffset = repcodes.rep[repcode]; } } + } else { + rawOffset = OFFBASE_TO_OFFSET(inSeqs[i].offBase); } outSeqs[i].offset = rawOffset; - /* seqStoreSeqs[i].offset == offCode+1, and ZSTD_updateRep() expects offCode - so we provide seqStoreSeqs[i].offset - 1 */ - ZSTD_updateRep(updatedRepcodes.rep, - seqStoreSeqs[i].offBase - 1, - seqStoreSeqs[i].litLength == 0); - literalsRead += outSeqs[i].litLength; + + /* Update repcode history for the sequence */ + ZSTD_updateRep(repcodes.rep, + inSeqs[i].offBase, + inSeqs[i].litLength == 0); + + nbOutLiterals += outSeqs[i].litLength; } /* Insert last literals (if any exist) in the block as a sequence with ml == off == 0. * If there are no last literals, then we'll emit (of: 0, ml: 0, ll: 0), which is a marker * for the block boundary, according to the API. */ - assert(seqStoreLiteralsSize >= literalsRead); - lastLLSize = seqStoreLiteralsSize - literalsRead; - outSeqs[i].litLength = (U32)lastLLSize; - outSeqs[i].matchLength = outSeqs[i].offset = outSeqs[i].rep = 0; - seqStoreSeqSize++; - zc->seqCollector.seqIndex += seqStoreSeqSize; + assert(nbInLiterals >= nbOutLiterals); + { + const size_t lastLLSize = nbInLiterals - nbOutLiterals; + outSeqs[nbInSequences].litLength = (U32)lastLLSize; + outSeqs[nbInSequences].matchLength = 0; + outSeqs[nbInSequences].offset = 0; + assert(nbOutSequences == nbInSequences + 1); + } + seqCollector->seqIndex += nbOutSequences; + assert(seqCollector->seqIndex <= seqCollector->maxSequences); + + return 0; +} + +size_t ZSTD_sequenceBound(size_t srcSize) { + const size_t maxNbSeq = (srcSize / ZSTD_MINMATCH_MIN) + 1; + const size_t maxNbDelims = (srcSize / ZSTD_BLOCKSIZE_MAX_MIN) + 1; + return maxNbSeq + maxNbDelims; } size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, size_t outSeqsSize, const void* src, size_t srcSize) { const size_t dstCapacity = ZSTD_compressBound(srcSize); - void* dst = ZSTD_customMalloc(dstCapacity, ZSTD_defaultCMem); + void* dst; /* Make C90 happy. */ SeqCollector seqCollector; + { + int targetCBlockSize; + FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_targetCBlockSize, &targetCBlockSize), ""); + RETURN_ERROR_IF(targetCBlockSize != 0, parameter_unsupported, "targetCBlockSize != 0"); + } + { + int nbWorkers; + FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_nbWorkers, &nbWorkers), ""); + RETURN_ERROR_IF(nbWorkers != 0, parameter_unsupported, "nbWorkers != 0"); + } + dst = ZSTD_customMalloc(dstCapacity, ZSTD_defaultCMem); RETURN_ERROR_IF(dst == NULL, memory_allocation, "NULL pointer!"); seqCollector.collectSequences = 1; @@ -2880,8 +3457,12 @@ size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, seqCollector.maxSequences = outSeqsSize; zc->seqCollector = seqCollector; - ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); - ZSTD_customFree(dst, ZSTD_defaultCMem); + { + const size_t ret = ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); + ZSTD_customFree(dst, ZSTD_defaultCMem); + FORWARD_IF_ERROR(ret, "ZSTD_compress2 failed"); + } + assert(zc->seqCollector.seqIndex <= ZSTD_sequenceBound(srcSize)); return zc->seqCollector.seqIndex; } @@ -2910,19 +3491,17 @@ static int ZSTD_isRLE(const BYTE* src, size_t length) { const size_t unrollMask = unrollSize - 1; const size_t prefixLength = length & unrollMask; size_t i; - size_t u; if (length == 1) return 1; /* Check if prefix is RLE first before using unrolled loop */ if (prefixLength && ZSTD_count(ip+1, ip, ip+prefixLength) != prefixLength-1) { return 0; } for (i = prefixLength; i != length; i += unrollSize) { + size_t u; for (u = 0; u < unrollSize; u += sizeof(size_t)) { if (MEM_readST(ip + i + u) != valueST) { return 0; - } - } - } + } } } return 1; } @@ -2930,7 +3509,7 @@ static int ZSTD_isRLE(const BYTE* src, size_t length) { * This is just a heuristic based on the compressibility. * It may return both false positives and false negatives. */ -static int ZSTD_maybeRLE(seqStore_t const* seqStore) +static int ZSTD_maybeRLE(SeqStore_t const* seqStore) { size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); size_t const nbLits = (size_t)(seqStore->lit - seqStore->litStart); @@ -2938,7 +3517,8 @@ static int ZSTD_maybeRLE(seqStore_t const* seqStore) return nbSeqs < 4 && nbLits < 10; } -static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) +static void +ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) { ZSTD_compressedBlockState_t* const tmp = bs->prevCBlock; bs->prevCBlock = bs->nextCBlock; @@ -2946,12 +3526,14 @@ static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* c } /* Writes the block header */ -static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) { +static void +writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) +{ U32 const cBlockHeader = cSize == 1 ? lastBlock + (((U32)bt_rle)<<1) + (U32)(blockSize << 3) : lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); MEM_writeLE24(op, cBlockHeader); - DEBUGLOG(3, "writeBlockHeader: cSize: %zu blockSize: %zu lastBlock: %u", cSize, blockSize, lastBlock); + DEBUGLOG(5, "writeBlockHeader: cSize: %zu blockSize: %zu lastBlock: %u", cSize, blockSize, lastBlock); } /* ZSTD_buildBlockEntropyStats_literals() : @@ -2959,13 +3541,16 @@ static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastB * Stores literals block type (raw, rle, compressed, repeat) and * huffman description table to hufMetadata. * Requires ENTROPY_WORKSPACE_SIZE workspace - * @return : size of huffman description table or error code */ -static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, - const ZSTD_hufCTables_t* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_hufCTablesMetadata_t* hufMetadata, - const int literalsCompressionIsDisabled, - void* workspace, size_t wkspSize) + * @return : size of huffman description table, or an error code + */ +static size_t +ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_hufCTablesMetadata_t* hufMetadata, + const int literalsCompressionIsDisabled, + void* workspace, size_t wkspSize, + int hufFlags) { BYTE* const wkspStart = (BYTE*)workspace; BYTE* const wkspEnd = wkspStart + wkspSize; @@ -2973,9 +3558,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi unsigned* const countWksp = (unsigned*)workspace; const size_t countWkspSize = (HUF_SYMBOLVALUE_MAX + 1) * sizeof(unsigned); BYTE* const nodeWksp = countWkspStart + countWkspSize; - const size_t nodeWkspSize = wkspEnd-nodeWksp; + const size_t nodeWkspSize = (size_t)(wkspEnd - nodeWksp); unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; - unsigned huffLog = HUF_TABLELOG_DEFAULT; + unsigned huffLog = LitHufLog; HUF_repeat repeat = prevHuf->repeatMode; DEBUGLOG(5, "ZSTD_buildBlockEntropyStats_literals (srcSize=%zu)", srcSize); @@ -2990,73 +3575,77 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi /* small ? don't even attempt compression (speed opt) */ #ifndef COMPRESS_LITERALS_SIZE_MIN -#define COMPRESS_LITERALS_SIZE_MIN 63 +# define COMPRESS_LITERALS_SIZE_MIN 63 /* heuristic */ #endif { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; if (srcSize <= minLitSize) { DEBUGLOG(5, "set_basic - too small"); hufMetadata->hType = set_basic; return 0; - } - } + } } /* Scan input and build symbol stats */ - { size_t const largest = HIST_count_wksp (countWksp, &maxSymbolValue, (const BYTE*)src, srcSize, workspace, wkspSize); + { size_t const largest = + HIST_count_wksp (countWksp, &maxSymbolValue, + (const BYTE*)src, srcSize, + workspace, wkspSize); FORWARD_IF_ERROR(largest, "HIST_count_wksp failed"); if (largest == srcSize) { + /* only one literal symbol */ DEBUGLOG(5, "set_rle"); hufMetadata->hType = set_rle; return 0; } if (largest <= (srcSize >> 7)+4) { + /* heuristic: likely not compressible */ DEBUGLOG(5, "set_basic - no gain"); hufMetadata->hType = set_basic; return 0; - } - } + } } /* Validate the previous Huffman table */ - if (repeat == HUF_repeat_check && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { + if (repeat == HUF_repeat_check + && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { repeat = HUF_repeat_none; } /* Build Huffman Tree */ ZSTD_memset(nextHuf->CTable, 0, sizeof(nextHuf->CTable)); - huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); + huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, nodeWksp, nodeWkspSize, nextHuf->CTable, countWksp, hufFlags); + assert(huffLog <= LitHufLog); { size_t const maxBits = HUF_buildCTable_wksp((HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue, huffLog, nodeWksp, nodeWkspSize); FORWARD_IF_ERROR(maxBits, "HUF_buildCTable_wksp"); huffLog = (U32)maxBits; - { /* Build and write the CTable */ - size_t const newCSize = HUF_estimateCompressedSize( - (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); - size_t const hSize = HUF_writeCTable_wksp( - hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), - (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, - nodeWksp, nodeWkspSize); - /* Check against repeating the previous CTable */ - if (repeat != HUF_repeat_none) { - size_t const oldCSize = HUF_estimateCompressedSize( - (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); - if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { - DEBUGLOG(5, "set_repeat - smaller"); - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - hufMetadata->hType = set_repeat; - return 0; - } - } - if (newCSize + hSize >= srcSize) { - DEBUGLOG(5, "set_basic - no gains"); + } + { /* Build and write the CTable */ + size_t const newCSize = HUF_estimateCompressedSize( + (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); + size_t const hSize = HUF_writeCTable_wksp( + hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), + (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, + nodeWksp, nodeWkspSize); + /* Check against repeating the previous CTable */ + if (repeat != HUF_repeat_none) { + size_t const oldCSize = HUF_estimateCompressedSize( + (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); + if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { + DEBUGLOG(5, "set_repeat - smaller"); ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - hufMetadata->hType = set_basic; + hufMetadata->hType = set_repeat; return 0; - } - DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); - hufMetadata->hType = set_compressed; - nextHuf->repeatMode = HUF_repeat_check; - return hSize; + } } + if (newCSize + hSize >= srcSize) { + DEBUGLOG(5, "set_basic - no gains"); + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + hufMetadata->hType = set_basic; + return 0; } + DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); + hufMetadata->hType = set_compressed; + nextHuf->repeatMode = HUF_repeat_check; + return hSize; } } @@ -3066,8 +3655,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi * and updates nextEntropy to the appropriate repeatMode. */ static ZSTD_symbolEncodingTypeStats_t -ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { - ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0}; +ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) +{ + ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0, 0}; nextEntropy->litlength_repeatMode = FSE_repeat_none; nextEntropy->offcode_repeatMode = FSE_repeat_none; nextEntropy->matchlength_repeatMode = FSE_repeat_none; @@ -3078,16 +3668,18 @@ ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { * Builds entropy for the sequences. * Stores symbol compression modes and fse table to fseMetadata. * Requires ENTROPY_WORKSPACE_SIZE wksp. - * @return : size of fse tables or error code */ -static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, - const ZSTD_fseCTables_t* prevEntropy, - ZSTD_fseCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_fseCTablesMetadata_t* fseMetadata, - void* workspace, size_t wkspSize) + * @return : size of fse tables or error code */ +static size_t +ZSTD_buildBlockEntropyStats_sequences( + const SeqStore_t* seqStorePtr, + const ZSTD_fseCTables_t* prevEntropy, + ZSTD_fseCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_fseCTablesMetadata_t* fseMetadata, + void* workspace, size_t wkspSize) { ZSTD_strategy const strategy = cctxParams->cParams.strategy; - size_t const nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; + size_t const nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); BYTE* const ostart = fseMetadata->fseTablesBuffer; BYTE* const oend = ostart + sizeof(fseMetadata->fseTablesBuffer); BYTE* op = ostart; @@ -3103,9 +3695,9 @@ static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, entropyWorkspace, entropyWorkspaceSize) : ZSTD_buildDummySequencesStatistics(nextEntropy); FORWARD_IF_ERROR(stats.size, "ZSTD_buildSequencesStatistics failed!"); - fseMetadata->llType = (symbolEncodingType_e) stats.LLtype; - fseMetadata->ofType = (symbolEncodingType_e) stats.Offtype; - fseMetadata->mlType = (symbolEncodingType_e) stats.MLtype; + fseMetadata->llType = (SymbolEncodingType_e) stats.LLtype; + fseMetadata->ofType = (SymbolEncodingType_e) stats.Offtype; + fseMetadata->mlType = (SymbolEncodingType_e) stats.MLtype; fseMetadata->lastCountSize = stats.lastCountSize; return stats.size; } @@ -3114,23 +3706,28 @@ static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, /* ZSTD_buildBlockEntropyStats() : * Builds entropy for the block. * Requires workspace size ENTROPY_WORKSPACE_SIZE - * - * @return : 0 on success or error code + * @return : 0 on success, or an error code + * Note : also employed in superblock */ -size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize) -{ - size_t const litSize = seqStorePtr->lit - seqStorePtr->litStart; +size_t ZSTD_buildBlockEntropyStats( + const SeqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize) +{ + size_t const litSize = (size_t)(seqStorePtr->lit - seqStorePtr->litStart); + int const huf_useOptDepth = (cctxParams->cParams.strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD); + int const hufFlags = huf_useOptDepth ? HUF_flags_optimalDepth : 0; + entropyMetadata->hufMetadata.hufDesSize = ZSTD_buildBlockEntropyStats_literals(seqStorePtr->litStart, litSize, &prevEntropy->huf, &nextEntropy->huf, &entropyMetadata->hufMetadata, ZSTD_literalsCompressionIsDisabled(cctxParams), - workspace, wkspSize); + workspace, wkspSize, hufFlags); + FORWARD_IF_ERROR(entropyMetadata->hufMetadata.hufDesSize, "ZSTD_buildBlockEntropyStats_literals failed"); entropyMetadata->fseMetadata.fseTablesSize = ZSTD_buildBlockEntropyStats_sequences(seqStorePtr, @@ -3143,11 +3740,12 @@ size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, } /* Returns the size estimate for the literals section (header + content) of a block */ -static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, - const ZSTD_hufCTables_t* huf, - const ZSTD_hufCTablesMetadata_t* hufMetadata, - void* workspace, size_t wkspSize, - int writeEntropy) +static size_t +ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, + const ZSTD_hufCTables_t* huf, + const ZSTD_hufCTablesMetadata_t* hufMetadata, + void* workspace, size_t wkspSize, + int writeEntropy) { unsigned* const countWksp = (unsigned*)workspace; unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; @@ -3169,12 +3767,13 @@ static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSiz } /* Returns the size estimate for the FSE-compressed symbols (of, ml, ll) of a block */ -static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, - const BYTE* codeTable, size_t nbSeq, unsigned maxCode, - const FSE_CTable* fseCTable, - const U8* additionalBits, - short const* defaultNorm, U32 defaultNormLog, U32 defaultMax, - void* workspace, size_t wkspSize) +static size_t +ZSTD_estimateBlockSize_symbolType(SymbolEncodingType_e type, + const BYTE* codeTable, size_t nbSeq, unsigned maxCode, + const FSE_CTable* fseCTable, + const U8* additionalBits, + short const* defaultNorm, U32 defaultNormLog, U32 defaultMax, + void* workspace, size_t wkspSize) { unsigned* const countWksp = (unsigned*)workspace; const BYTE* ctp = codeTable; @@ -3206,116 +3805,121 @@ static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, } /* Returns the size estimate for the sequences section (header + content) of a block */ -static size_t ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, - const BYTE* llCodeTable, - const BYTE* mlCodeTable, - size_t nbSeq, - const ZSTD_fseCTables_t* fseTables, - const ZSTD_fseCTablesMetadata_t* fseMetadata, - void* workspace, size_t wkspSize, - int writeEntropy) +static size_t +ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, + const BYTE* llCodeTable, + const BYTE* mlCodeTable, + size_t nbSeq, + const ZSTD_fseCTables_t* fseTables, + const ZSTD_fseCTablesMetadata_t* fseMetadata, + void* workspace, size_t wkspSize, + int writeEntropy) { size_t sequencesSectionHeaderSize = 1 /* seqHead */ + 1 /* min seqSize size */ + (nbSeq >= 128) + (nbSeq >= LONGNBSEQ); size_t cSeqSizeEstimate = 0; cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->ofType, ofCodeTable, nbSeq, MaxOff, - fseTables->offcodeCTable, NULL, - OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, - workspace, wkspSize); + fseTables->offcodeCTable, NULL, + OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, + workspace, wkspSize); cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->llType, llCodeTable, nbSeq, MaxLL, - fseTables->litlengthCTable, LL_bits, - LL_defaultNorm, LL_defaultNormLog, MaxLL, - workspace, wkspSize); + fseTables->litlengthCTable, LL_bits, + LL_defaultNorm, LL_defaultNormLog, MaxLL, + workspace, wkspSize); cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->mlType, mlCodeTable, nbSeq, MaxML, - fseTables->matchlengthCTable, ML_bits, - ML_defaultNorm, ML_defaultNormLog, MaxML, - workspace, wkspSize); + fseTables->matchlengthCTable, ML_bits, + ML_defaultNorm, ML_defaultNormLog, MaxML, + workspace, wkspSize); if (writeEntropy) cSeqSizeEstimate += fseMetadata->fseTablesSize; return cSeqSizeEstimate + sequencesSectionHeaderSize; } /* Returns the size estimate for a given stream of literals, of, ll, ml */ -static size_t ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, - const BYTE* ofCodeTable, - const BYTE* llCodeTable, - const BYTE* mlCodeTable, - size_t nbSeq, - const ZSTD_entropyCTables_t* entropy, - const ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize, - int writeLitEntropy, int writeSeqEntropy) { +static size_t +ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, + const BYTE* ofCodeTable, + const BYTE* llCodeTable, + const BYTE* mlCodeTable, + size_t nbSeq, + const ZSTD_entropyCTables_t* entropy, + const ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize, + int writeLitEntropy, int writeSeqEntropy) +{ size_t const literalsSize = ZSTD_estimateBlockSize_literal(literals, litSize, - &entropy->huf, &entropyMetadata->hufMetadata, - workspace, wkspSize, writeLitEntropy); + &entropy->huf, &entropyMetadata->hufMetadata, + workspace, wkspSize, writeLitEntropy); size_t const seqSize = ZSTD_estimateBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, - nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, - workspace, wkspSize, writeSeqEntropy); + nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, + workspace, wkspSize, writeSeqEntropy); return seqSize + literalsSize + ZSTD_blockHeaderSize; } /* Builds entropy statistics and uses them for blocksize estimation. * - * Returns the estimated compressed size of the seqStore, or a zstd error. + * @return: estimated compressed size of the seqStore, or a zstd error. */ -static size_t ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(seqStore_t* seqStore, ZSTD_CCtx* zc) { - ZSTD_entropyCTablesMetadata_t* entropyMetadata = &zc->blockSplitCtx.entropyMetadata; +static size_t +ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(SeqStore_t* seqStore, ZSTD_CCtx* zc) +{ + ZSTD_entropyCTablesMetadata_t* const entropyMetadata = &zc->blockSplitCtx.entropyMetadata; DEBUGLOG(6, "ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize()"); FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(seqStore, &zc->blockState.prevCBlock->entropy, &zc->blockState.nextCBlock->entropy, &zc->appliedParams, entropyMetadata, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */), ""); - return ZSTD_estimateBlockSize(seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), + zc->tmpWorkspace, zc->tmpWkspSize), ""); + return ZSTD_estimateBlockSize( + seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), seqStore->ofCode, seqStore->llCode, seqStore->mlCode, (size_t)(seqStore->sequences - seqStore->sequencesStart), - &zc->blockState.nextCBlock->entropy, entropyMetadata, zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE, + &zc->blockState.nextCBlock->entropy, + entropyMetadata, + zc->tmpWorkspace, zc->tmpWkspSize, (int)(entropyMetadata->hufMetadata.hType == set_compressed), 1); } /* Returns literals bytes represented in a seqStore */ -static size_t ZSTD_countSeqStoreLiteralsBytes(const seqStore_t* const seqStore) { +static size_t ZSTD_countSeqStoreLiteralsBytes(const SeqStore_t* const seqStore) +{ size_t literalsBytes = 0; - size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; + size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); size_t i; for (i = 0; i < nbSeqs; ++i) { - seqDef seq = seqStore->sequencesStart[i]; + SeqDef const seq = seqStore->sequencesStart[i]; literalsBytes += seq.litLength; if (i == seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_literalLength) { literalsBytes += 0x10000; - } - } + } } return literalsBytes; } /* Returns match bytes represented in a seqStore */ -static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) { +static size_t ZSTD_countSeqStoreMatchBytes(const SeqStore_t* const seqStore) +{ size_t matchBytes = 0; - size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; + size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); size_t i; for (i = 0; i < nbSeqs; ++i) { - seqDef seq = seqStore->sequencesStart[i]; + SeqDef seq = seqStore->sequencesStart[i]; matchBytes += seq.mlBase + MINMATCH; if (i == seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_matchLength) { matchBytes += 0x10000; - } - } + } } return matchBytes; } /* Derives the seqStore that is a chunk of the originalSeqStore from [startIdx, endIdx). * Stores the result in resultSeqStore. */ -static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, - const seqStore_t* originalSeqStore, - size_t startIdx, size_t endIdx) { - BYTE* const litEnd = originalSeqStore->lit; - size_t literalsBytes; - size_t literalsBytesPreceding = 0; - +static void ZSTD_deriveSeqStoreChunk(SeqStore_t* resultSeqStore, + const SeqStore_t* originalSeqStore, + size_t startIdx, size_t endIdx) +{ *resultSeqStore = *originalSeqStore; if (startIdx > 0) { resultSeqStore->sequences = originalSeqStore->sequencesStart + startIdx; - literalsBytesPreceding = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); + resultSeqStore->litStart += ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); } /* Move longLengthPos into the correct position if necessary */ @@ -3328,13 +3932,12 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, } resultSeqStore->sequencesStart = originalSeqStore->sequencesStart + startIdx; resultSeqStore->sequences = originalSeqStore->sequencesStart + endIdx; - literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); - resultSeqStore->litStart += literalsBytesPreceding; if (endIdx == (size_t)(originalSeqStore->sequences - originalSeqStore->sequencesStart)) { /* This accounts for possible last literals if the derived chunk reaches the end of the block */ - resultSeqStore->lit = litEnd; + assert(resultSeqStore->lit == originalSeqStore->lit); } else { - resultSeqStore->lit = resultSeqStore->litStart+literalsBytes; + size_t const literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); + resultSeqStore->lit = resultSeqStore->litStart + literalsBytes; } resultSeqStore->llCode += startIdx; resultSeqStore->mlCode += startIdx; @@ -3342,20 +3945,26 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, } /* - * Returns the raw offset represented by the combination of offCode, ll0, and repcode history. - * offCode must represent a repcode in the numeric representation of ZSTD_storeSeq(). + * Returns the raw offset represented by the combination of offBase, ll0, and repcode history. + * offBase must represent a repcode in the numeric representation of ZSTD_storeSeq(). */ static U32 -ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, const U32 ll0) -{ - U32 const adjustedOffCode = STORED_REPCODE(offCode) - 1 + ll0; /* [ 0 - 3 ] */ - assert(STORED_IS_REPCODE(offCode)); - if (adjustedOffCode == ZSTD_REP_NUM) { - /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 */ - assert(rep[0] > 0); +ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offBase, const U32 ll0) +{ + U32 const adjustedRepCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; /* [ 0 - 3 ] */ + assert(OFFBASE_IS_REPCODE(offBase)); + if (adjustedRepCode == ZSTD_REP_NUM) { + assert(ll0); + /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 + * This is only valid if it results in a valid offset value, aka > 0. + * Note : it may happen that `rep[0]==1` in exceptional circumstances. + * In which case this function will return 0, which is an invalid offset. + * It's not an issue though, since this value will be + * compared and discarded within ZSTD_seqStore_resolveOffCodes(). + */ return rep[0] - 1; } - return rep[adjustedOffCode]; + return rep[adjustedRepCode]; } /* @@ -3371,30 +3980,33 @@ ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, c * 1-3 : repcode 1-3 * 4+ : real_offset+3 */ -static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_t* const cRepcodes, - seqStore_t* const seqStore, U32 const nbSeq) { +static void +ZSTD_seqStore_resolveOffCodes(Repcodes_t* const dRepcodes, Repcodes_t* const cRepcodes, + const SeqStore_t* const seqStore, U32 const nbSeq) +{ U32 idx = 0; + U32 const longLitLenIdx = seqStore->longLengthType == ZSTD_llt_literalLength ? seqStore->longLengthPos : nbSeq; for (; idx < nbSeq; ++idx) { - seqDef* const seq = seqStore->sequencesStart + idx; - U32 const ll0 = (seq->litLength == 0); - U32 const offCode = OFFBASE_TO_STORED(seq->offBase); - assert(seq->offBase > 0); - if (STORED_IS_REPCODE(offCode)) { - U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offCode, ll0); - U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offCode, ll0); + SeqDef* const seq = seqStore->sequencesStart + idx; + U32 const ll0 = (seq->litLength == 0) && (idx != longLitLenIdx); + U32 const offBase = seq->offBase; + assert(offBase > 0); + if (OFFBASE_IS_REPCODE(offBase)) { + U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offBase, ll0); + U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offBase, ll0); /* Adjust simulated decompression repcode history if we come across a mismatch. Replace * the repcode with the offset it actually references, determined by the compression * repcode history. */ if (dRawOffset != cRawOffset) { - seq->offBase = cRawOffset + ZSTD_REP_NUM; + seq->offBase = OFFSET_TO_OFFBASE(cRawOffset); } } /* Compression repcode history is always updated with values directly from the unmodified seqStore. * Decompression repcode history may use modified seq->offset value taken from compression repcode history. */ - ZSTD_updateRep(dRepcodes->rep, OFFBASE_TO_STORED(seq->offBase), ll0); - ZSTD_updateRep(cRepcodes->rep, offCode, ll0); + ZSTD_updateRep(dRepcodes->rep, seq->offBase, ll0); + ZSTD_updateRep(cRepcodes->rep, offBase, ll0); } } @@ -3404,10 +4016,11 @@ static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_ * Returns the total size of that block (including header) or a ZSTD error code. */ static size_t -ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, - repcodes_t* const dRep, repcodes_t* const cRep, +ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, + const SeqStore_t* const seqStore, + Repcodes_t* const dRep, Repcodes_t* const cRep, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, + const void* src, size_t srcSize, U32 lastBlock, U32 isPartition) { const U32 rleMaxLength = 25; @@ -3417,7 +4030,7 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, size_t cSeqsSize; /* In case of an RLE or raw block, the simulated decompression repcode history must be reset */ - repcodes_t const dRepOriginal = *dRep; + Repcodes_t const dRepOriginal = *dRep; DEBUGLOG(5, "ZSTD_compressSeqStore_singleBlock"); if (isPartition) ZSTD_seqStore_resolveOffCodes(dRep, cRep, seqStore, (U32)(seqStore->sequences - seqStore->sequencesStart)); @@ -3428,7 +4041,7 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, &zc->appliedParams, op + ZSTD_blockHeaderSize, dstCapacity - ZSTD_blockHeaderSize, srcSize, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */, + zc->tmpWorkspace, zc->tmpWkspSize /* statically allocated in resetCCtx */, zc->bmi2); FORWARD_IF_ERROR(cSeqsSize, "ZSTD_entropyCompressSeqStore failed!"); @@ -3442,8 +4055,9 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, cSeqsSize = 1; } + /* Sequence collection not supported when block splitting */ if (zc->seqCollector.collectSequences) { - ZSTD_copyBlockSequences(zc); + FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, seqStore, dRepOriginal.rep), "copyBlockSequences failed"); ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); return 0; } @@ -3451,18 +4065,18 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, if (cSeqsSize == 0) { cSize = ZSTD_noCompressBlock(op, dstCapacity, ip, srcSize, lastBlock); FORWARD_IF_ERROR(cSize, "Nocompress block failed"); - DEBUGLOG(4, "Writing out nocompress block, size: %zu", cSize); + DEBUGLOG(5, "Writing out nocompress block, size: %zu", cSize); *dRep = dRepOriginal; /* reset simulated decompression repcode history */ } else if (cSeqsSize == 1) { cSize = ZSTD_rleCompressBlock(op, dstCapacity, *ip, srcSize, lastBlock); FORWARD_IF_ERROR(cSize, "RLE compress block failed"); - DEBUGLOG(4, "Writing out RLE block, size: %zu", cSize); + DEBUGLOG(5, "Writing out RLE block, size: %zu", cSize); *dRep = dRepOriginal; /* reset simulated decompression repcode history */ } else { ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); writeBlockHeader(op, cSeqsSize, srcSize, lastBlock); cSize = ZSTD_blockHeaderSize + cSeqsSize; - DEBUGLOG(4, "Writing out compressed block, size: %zu", cSize); + DEBUGLOG(5, "Writing out compressed block, size: %zu", cSize); } if (zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode == FSE_repeat_valid) @@ -3481,45 +4095,49 @@ typedef struct { /* Helper function to perform the recursive search for block splits. * Estimates the cost of seqStore prior to split, and estimates the cost of splitting the sequences in half. - * If advantageous to split, then we recurse down the two sub-blocks. If not, or if an error occurred in estimation, then - * we do not recurse. + * If advantageous to split, then we recurse down the two sub-blocks. + * If not, or if an error occurred in estimation, then we do not recurse. * - * Note: The recursion depth is capped by a heuristic minimum number of sequences, defined by MIN_SEQUENCES_BLOCK_SPLITTING. + * Note: The recursion depth is capped by a heuristic minimum number of sequences, + * defined by MIN_SEQUENCES_BLOCK_SPLITTING. * In theory, this means the absolute largest recursion depth is 10 == log2(maxNbSeqInBlock/MIN_SEQUENCES_BLOCK_SPLITTING). * In practice, recursion depth usually doesn't go beyond 4. * - * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize + * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. + * At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize * maximum of 128 KB, this value is actually impossible to reach. */ static void ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t endIdx, - ZSTD_CCtx* zc, const seqStore_t* origSeqStore) + ZSTD_CCtx* zc, const SeqStore_t* origSeqStore) { - seqStore_t* fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; - seqStore_t* firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; - seqStore_t* secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; + SeqStore_t* const fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; + SeqStore_t* const firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; + SeqStore_t* const secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; size_t estimatedOriginalSize; size_t estimatedFirstHalfSize; size_t estimatedSecondHalfSize; size_t midIdx = (startIdx + endIdx)/2; + DEBUGLOG(5, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); + assert(endIdx >= startIdx); if (endIdx - startIdx < MIN_SEQUENCES_BLOCK_SPLITTING || splits->idx >= ZSTD_MAX_NB_BLOCK_SPLITS) { - DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences"); + DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences (%zu)", endIdx - startIdx); return; } - DEBUGLOG(4, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); ZSTD_deriveSeqStoreChunk(fullSeqStoreChunk, origSeqStore, startIdx, endIdx); ZSTD_deriveSeqStoreChunk(firstHalfSeqStore, origSeqStore, startIdx, midIdx); ZSTD_deriveSeqStoreChunk(secondHalfSeqStore, origSeqStore, midIdx, endIdx); estimatedOriginalSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(fullSeqStoreChunk, zc); estimatedFirstHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(firstHalfSeqStore, zc); estimatedSecondHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(secondHalfSeqStore, zc); - DEBUGLOG(4, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", + DEBUGLOG(5, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", estimatedOriginalSize, estimatedFirstHalfSize, estimatedSecondHalfSize); if (ZSTD_isError(estimatedOriginalSize) || ZSTD_isError(estimatedFirstHalfSize) || ZSTD_isError(estimatedSecondHalfSize)) { return; } if (estimatedFirstHalfSize + estimatedSecondHalfSize < estimatedOriginalSize) { + DEBUGLOG(5, "split decided at seqNb:%zu", midIdx); ZSTD_deriveBlockSplitsHelper(splits, startIdx, midIdx, zc, origSeqStore); splits->splitLocations[splits->idx] = (U32)midIdx; splits->idx++; @@ -3527,14 +4145,18 @@ ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t end } } -/* Base recursive function. Populates a table with intra-block partition indices that can improve compression ratio. +/* Base recursive function. + * Populates a table with intra-block partition indices that can improve compression ratio. * - * Returns the number of splits made (which equals the size of the partition table - 1). + * @return: number of splits made (which equals the size of the partition table - 1). */ -static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) { - seqStoreSplits splits = {partitions, 0}; +static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) +{ + seqStoreSplits splits; + splits.splitLocations = partitions; + splits.idx = 0; if (nbSeq <= 4) { - DEBUGLOG(4, "ZSTD_deriveBlockSplits: Too few sequences to split"); + DEBUGLOG(5, "ZSTD_deriveBlockSplits: Too few sequences to split (%u <= 4)", nbSeq); /* Refuse to try and split anything with less than 4 sequences */ return 0; } @@ -3550,18 +4172,20 @@ static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) * Returns combined size of all blocks (which includes headers), or a ZSTD error code. */ static size_t -ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, - const void* src, size_t blockSize, U32 lastBlock, U32 nbSeq) +ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, + void* dst, size_t dstCapacity, + const void* src, size_t blockSize, + U32 lastBlock, U32 nbSeq) { size_t cSize = 0; const BYTE* ip = (const BYTE*)src; BYTE* op = (BYTE*)dst; size_t i = 0; size_t srcBytesTotal = 0; - U32* partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ - seqStore_t* nextSeqStore = &zc->blockSplitCtx.nextSeqStore; - seqStore_t* currSeqStore = &zc->blockSplitCtx.currSeqStore; - size_t numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); + U32* const partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ + SeqStore_t* const nextSeqStore = &zc->blockSplitCtx.nextSeqStore; + SeqStore_t* const currSeqStore = &zc->blockSplitCtx.currSeqStore; + size_t const numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); /* If a block is split and some partitions are emitted as RLE/uncompressed, then repcode history * may become invalid. In order to reconcile potentially invalid repcodes, we keep track of two @@ -3577,36 +4201,37 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac * * See ZSTD_seqStore_resolveOffCodes() for more details. */ - repcodes_t dRep; - repcodes_t cRep; - ZSTD_memcpy(dRep.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); - ZSTD_memcpy(cRep.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); - ZSTD_memset(nextSeqStore, 0, sizeof(seqStore_t)); + Repcodes_t dRep; + Repcodes_t cRep; + ZSTD_memcpy(dRep.rep, zc->blockState.prevCBlock->rep, sizeof(Repcodes_t)); + ZSTD_memcpy(cRep.rep, zc->blockState.prevCBlock->rep, sizeof(Repcodes_t)); + ZSTD_memset(nextSeqStore, 0, sizeof(SeqStore_t)); - DEBUGLOG(4, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", + DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", (unsigned)dstCapacity, (unsigned)zc->blockState.matchState.window.dictLimit, (unsigned)zc->blockState.matchState.nextToUpdate); if (numSplits == 0) { - size_t cSizeSingleBlock = ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, - &dRep, &cRep, - op, dstCapacity, - ip, blockSize, - lastBlock, 0 /* isPartition */); + size_t cSizeSingleBlock = + ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, + &dRep, &cRep, + op, dstCapacity, + ip, blockSize, + lastBlock, 0 /* isPartition */); FORWARD_IF_ERROR(cSizeSingleBlock, "Compressing single block from splitBlock_internal() failed!"); DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal: No splits"); - assert(cSizeSingleBlock <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); + assert(zc->blockSizeMax <= ZSTD_BLOCKSIZE_MAX); + assert(cSizeSingleBlock <= zc->blockSizeMax + ZSTD_blockHeaderSize); return cSizeSingleBlock; } ZSTD_deriveSeqStoreChunk(currSeqStore, &zc->seqStore, 0, partitions[0]); for (i = 0; i <= numSplits; ++i) { - size_t srcBytes; size_t cSizeChunk; U32 const lastPartition = (i == numSplits); U32 lastBlockEntireSrc = 0; - srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); + size_t srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); srcBytesTotal += srcBytes; if (lastPartition) { /* This is the final partition, need to account for possible last literals */ @@ -3621,7 +4246,8 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac op, dstCapacity, ip, srcBytes, lastBlockEntireSrc, 1 /* isPartition */); - DEBUGLOG(5, "Estimated size: %zu actual size: %zu", ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); + DEBUGLOG(5, "Estimated size: %zu vs %zu : actual size", + ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); FORWARD_IF_ERROR(cSizeChunk, "Compressing chunk failed!"); ip += srcBytes; @@ -3629,12 +4255,12 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac dstCapacity -= cSizeChunk; cSize += cSizeChunk; *currSeqStore = *nextSeqStore; - assert(cSizeChunk <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); + assert(cSizeChunk <= zc->blockSizeMax + ZSTD_blockHeaderSize); } - /* cRep and dRep may have diverged during the compression. If so, we use the dRep repcodes - * for the next block. + /* cRep and dRep may have diverged during the compression. + * If so, we use the dRep repcodes for the next block. */ - ZSTD_memcpy(zc->blockState.prevCBlock->rep, dRep.rep, sizeof(repcodes_t)); + ZSTD_memcpy(zc->blockState.prevCBlock->rep, dRep.rep, sizeof(Repcodes_t)); return cSize; } @@ -3643,21 +4269,20 @@ ZSTD_compressBlock_splitBlock(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) { - const BYTE* ip = (const BYTE*)src; - BYTE* op = (BYTE*)dst; U32 nbSeq; size_t cSize; - DEBUGLOG(4, "ZSTD_compressBlock_splitBlock"); - assert(zc->appliedParams.useBlockSplitter == ZSTD_ps_enable); + DEBUGLOG(5, "ZSTD_compressBlock_splitBlock"); + assert(zc->appliedParams.postBlockSplitter == ZSTD_ps_enable); { const size_t bss = ZSTD_buildSeqStore(zc, src, srcSize); FORWARD_IF_ERROR(bss, "ZSTD_buildSeqStore failed"); if (bss == ZSTDbss_noCompress) { if (zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode == FSE_repeat_valid) zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode = FSE_repeat_check; - cSize = ZSTD_noCompressBlock(op, dstCapacity, ip, srcSize, lastBlock); + RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); + cSize = ZSTD_noCompressBlock(dst, dstCapacity, src, srcSize, lastBlock); FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); - DEBUGLOG(4, "ZSTD_compressBlock_splitBlock: Nocompress block"); + DEBUGLOG(5, "ZSTD_compressBlock_splitBlock: Nocompress block"); return cSize; } nbSeq = (U32)(zc->seqStore.sequences - zc->seqStore.sequencesStart); @@ -3673,9 +4298,9 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 frame) { - /* This the upper bound for the length of an rle block. - * This isn't the actual upper bound. Finding the real threshold - * needs further investigation. + /* This is an estimated upper bound for the length of an rle block. + * This isn't the actual upper bound. + * Finding the real threshold needs further investigation. */ const U32 rleMaxLength = 25; size_t cSize; @@ -3687,11 +4312,15 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, { const size_t bss = ZSTD_buildSeqStore(zc, src, srcSize); FORWARD_IF_ERROR(bss, "ZSTD_buildSeqStore failed"); - if (bss == ZSTDbss_noCompress) { cSize = 0; goto out; } + if (bss == ZSTDbss_noCompress) { + RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); + cSize = 0; + goto out; + } } if (zc->seqCollector.collectSequences) { - ZSTD_copyBlockSequences(zc); + FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, ZSTD_getSeqStore(zc), zc->blockState.prevCBlock->rep), "copyBlockSequences failed"); ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); return 0; } @@ -3702,7 +4331,7 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, &zc->appliedParams, dst, dstCapacity, srcSize, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */, + zc->tmpWorkspace, zc->tmpWkspSize /* statically allocated in resetCCtx */, zc->bmi2); if (frame && @@ -3767,10 +4396,11 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, * * cSize >= blockBound(srcSize): We have expanded the block too much so * emit an uncompressed block. */ - { - size_t const cSize = ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); + { size_t const cSize = + ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); if (cSize != ERROR(dstSize_tooSmall)) { - size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); + size_t const maxCSize = + srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); FORWARD_IF_ERROR(cSize, "ZSTD_compressSuperBlock failed"); if (cSize != 0 && cSize < maxCSize + ZSTD_blockHeaderSize) { ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); @@ -3778,7 +4408,7 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, } } } - } + } /* if (bss == ZSTDbss_compress)*/ DEBUGLOG(6, "Resorting to ZSTD_noCompressBlock()"); /* Superblock compression failed, attempt to emit a single no compress block. @@ -3807,7 +4437,7 @@ static size_t ZSTD_compressBlock_targetCBlockSize(ZSTD_CCtx* zc, return cSize; } -static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, +static void ZSTD_overflowCorrectIfNeeded(ZSTD_MatchState_t* ms, ZSTD_cwksp* ws, ZSTD_CCtx_params const* params, void const* ip, @@ -3831,39 +4461,82 @@ static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, } } +#include "zstd_preSplit.h" + +static size_t ZSTD_optimalBlockSize(ZSTD_CCtx* cctx, const void* src, size_t srcSize, size_t blockSizeMax, int splitLevel, ZSTD_strategy strat, S64 savings) +{ + /* split level based on compression strategy, from `fast` to `btultra2` */ + static const int splitLevels[] = { 0, 0, 1, 2, 2, 3, 3, 4, 4, 4 }; + /* note: conservatively only split full blocks (128 KB) currently. + * While it's possible to go lower, let's keep it simple for a first implementation. + * Besides, benefits of splitting are reduced when blocks are already small. + */ + if (srcSize < 128 KB || blockSizeMax < 128 KB) + return MIN(srcSize, blockSizeMax); + /* do not split incompressible data though: + * require verified savings to allow pre-splitting. + * Note: as a consequence, the first full block is not split. + */ + if (savings < 3) { + DEBUGLOG(6, "don't attempt splitting: savings (%i) too low", (int)savings); + return 128 KB; + } + /* apply @splitLevel, or use default value (which depends on @strat). + * note that splitting heuristic is still conditioned by @savings >= 3, + * so the first block will not reach this code path */ + if (splitLevel == 1) return 128 KB; + if (splitLevel == 0) { + assert(ZSTD_fast <= strat && strat <= ZSTD_btultra2); + splitLevel = splitLevels[strat]; + } else { + assert(2 <= splitLevel && splitLevel <= 6); + splitLevel -= 2; + } + return ZSTD_splitBlock(src, blockSizeMax, splitLevel, cctx->tmpWorkspace, cctx->tmpWkspSize); +} + /*! ZSTD_compress_frameChunk() : * Compress a chunk of data into one or multiple blocks. * All blocks will be terminated, all input will be consumed. * Function will issue an error if there is not enough `dstCapacity` to hold the compressed content. * Frame is supposed already started (header already produced) -* @return : compressed size, or an error code +* @return : compressed size, or an error code */ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastFrameChunk) { - size_t blockSize = cctx->blockSize; + size_t blockSizeMax = cctx->blockSizeMax; size_t remaining = srcSize; const BYTE* ip = (const BYTE*)src; BYTE* const ostart = (BYTE*)dst; BYTE* op = ostart; U32 const maxDist = (U32)1 << cctx->appliedParams.cParams.windowLog; + S64 savings = (S64)cctx->consumedSrcSize - (S64)cctx->producedCSize; assert(cctx->appliedParams.cParams.windowLog <= ZSTD_WINDOWLOG_MAX); - DEBUGLOG(4, "ZSTD_compress_frameChunk (blockSize=%u)", (unsigned)blockSize); + DEBUGLOG(5, "ZSTD_compress_frameChunk (srcSize=%u, blockSizeMax=%u)", (unsigned)srcSize, (unsigned)blockSizeMax); if (cctx->appliedParams.fParams.checksumFlag && srcSize) xxh64_update(&cctx->xxhState, src, srcSize); while (remaining) { - ZSTD_matchState_t* const ms = &cctx->blockState.matchState; - U32 const lastBlock = lastFrameChunk & (blockSize >= remaining); - - RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE, + ZSTD_MatchState_t* const ms = &cctx->blockState.matchState; + size_t const blockSize = ZSTD_optimalBlockSize(cctx, + ip, remaining, + blockSizeMax, + cctx->appliedParams.preBlockSplitter_level, + cctx->appliedParams.cParams.strategy, + savings); + U32 const lastBlock = lastFrameChunk & (blockSize == remaining); + assert(blockSize <= remaining); + + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE + 1, dstSize_tooSmall, "not enough space to store compressed block"); - if (remaining < blockSize) blockSize = remaining; ZSTD_overflowCorrectIfNeeded( ms, &cctx->workspace, &cctx->appliedParams, ip, ip + blockSize); @@ -3899,8 +4572,23 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, MEM_writeLE24(op, cBlockHeader); cSize += ZSTD_blockHeaderSize; } - } - + } /* if (ZSTD_useTargetCBlockSize(&cctx->appliedParams))*/ + + /* @savings is employed to ensure that splitting doesn't worsen expansion of incompressible data. + * Without splitting, the maximum expansion is 3 bytes per full block. + * An adversarial input could attempt to fudge the split detector, + * and make it split incompressible data, resulting in more block headers. + * Note that, since ZSTD_COMPRESSBOUND() assumes a worst case scenario of 1KB per block, + * and the splitter never creates blocks that small (current lower limit is 8 KB), + * there is already no risk to expand beyond ZSTD_COMPRESSBOUND() limit. + * But if the goal is to not expand by more than 3-bytes per 128 KB full block, + * then yes, it becomes possible to make the block splitter oversplit incompressible data. + * Using @savings, we enforce an even more conservative condition, + * requiring the presence of enough savings (at least 3 bytes) to authorize splitting, + * otherwise only full blocks are used. + * But being conservative is fine, + * since splitting barely compressible blocks is not fruitful anyway */ + savings += (S64)blockSize - (S64)cSize; ip += blockSize; assert(remaining >= blockSize); @@ -3919,8 +4607,10 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, static size_t ZSTD_writeFrameHeader(void* dst, size_t dstCapacity, - const ZSTD_CCtx_params* params, U64 pledgedSrcSize, U32 dictID) -{ BYTE* const op = (BYTE*)dst; + const ZSTD_CCtx_params* params, + U64 pledgedSrcSize, U32 dictID) +{ + BYTE* const op = (BYTE*)dst; U32 const dictIDSizeCodeLength = (dictID>0) + (dictID>=256) + (dictID>=65536); /* 0-3 */ U32 const dictIDSizeCode = params->fParams.noDictIDFlag ? 0 : dictIDSizeCodeLength; /* 0-3 */ U32 const checksumFlag = params->fParams.checksumFlag>0; @@ -4001,19 +4691,15 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity) } } -size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) +void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) { - RETURN_ERROR_IF(cctx->stage != ZSTDcs_init, stage_wrong, - "wrong cctx stage"); - RETURN_ERROR_IF(cctx->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable, - parameter_unsupported, - "incompatible with ldm"); + assert(cctx->stage == ZSTDcs_init); + assert(nbSeq == 0 || cctx->appliedParams.ldmParams.enableLdm != ZSTD_ps_enable); cctx->externSeqStore.seq = seq; cctx->externSeqStore.size = nbSeq; cctx->externSeqStore.capacity = nbSeq; cctx->externSeqStore.pos = 0; cctx->externSeqStore.posInSequence = 0; - return 0; } @@ -4022,7 +4708,7 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx, const void* src, size_t srcSize, U32 frame, U32 lastFrameChunk) { - ZSTD_matchState_t* const ms = &cctx->blockState.matchState; + ZSTD_MatchState_t* const ms = &cctx->blockState.matchState; size_t fhSize = 0; DEBUGLOG(5, "ZSTD_compressContinue_internal, stage: %u, srcSize: %u", @@ -4057,7 +4743,7 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx, src, (BYTE const*)src + srcSize); } - DEBUGLOG(5, "ZSTD_compressContinue_internal (blockSize=%u)", (unsigned)cctx->blockSize); + DEBUGLOG(5, "ZSTD_compressContinue_internal (blockSize=%u)", (unsigned)cctx->blockSizeMax); { size_t const cSize = frame ? ZSTD_compress_frameChunk (cctx, dst, dstCapacity, src, srcSize, lastFrameChunk) : ZSTD_compressBlock_internal (cctx, dst, dstCapacity, src, srcSize, 0 /* frame */); @@ -4078,58 +4764,90 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx, } } -size_t ZSTD_compressContinue (ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressContinue (srcSize=%u)", (unsigned)srcSize); return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 1 /* frame mode */, 0 /* last chunk */); } +/* NOTE: Must just wrap ZSTD_compressContinue_public() */ +size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_compressContinue_public(cctx, dst, dstCapacity, src, srcSize); +} -size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) +static size_t ZSTD_getBlockSize_deprecated(const ZSTD_CCtx* cctx) { ZSTD_compressionParameters const cParams = cctx->appliedParams.cParams; assert(!ZSTD_checkCParams(cParams)); - return MIN (ZSTD_BLOCKSIZE_MAX, (U32)1 << cParams.windowLog); + return MIN(cctx->appliedParams.maxBlockSize, (size_t)1 << cParams.windowLog); } -size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) +/* NOTE: Must just wrap ZSTD_getBlockSize_deprecated() */ +size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) +{ + return ZSTD_getBlockSize_deprecated(cctx); +} + +/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ +size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressBlock: srcSize = %u", (unsigned)srcSize); - { size_t const blockSizeMax = ZSTD_getBlockSize(cctx); + { size_t const blockSizeMax = ZSTD_getBlockSize_deprecated(cctx); RETURN_ERROR_IF(srcSize > blockSizeMax, srcSize_wrong, "input is larger than a block"); } return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 0 /* frame mode */, 0 /* last chunk */); } +/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ +size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) +{ + return ZSTD_compressBlock_deprecated(cctx, dst, dstCapacity, src, srcSize); +} + /*! ZSTD_loadDictionaryContent() : * @return : 0, or an error code */ -static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, - ldmState_t* ls, - ZSTD_cwksp* ws, - ZSTD_CCtx_params const* params, - const void* src, size_t srcSize, - ZSTD_dictTableLoadMethod_e dtlm) +static size_t +ZSTD_loadDictionaryContent(ZSTD_MatchState_t* ms, + ldmState_t* ls, + ZSTD_cwksp* ws, + ZSTD_CCtx_params const* params, + const void* src, size_t srcSize, + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) { const BYTE* ip = (const BYTE*) src; const BYTE* const iend = ip + srcSize; int const loadLdmDict = params->ldmParams.enableLdm == ZSTD_ps_enable && ls != NULL; - /* Assert that we the ms params match the params we're being given */ + /* Assert that the ms params match the params we're being given */ ZSTD_assertEqualCParams(params->cParams, ms->cParams); - if (srcSize > ZSTD_CHUNKSIZE_MAX) { + { /* Ensure large dictionaries can't cause index overflow */ + /* Allow the dictionary to set indices up to exactly ZSTD_CURRENT_MAX. * Dictionaries right at the edge will immediately trigger overflow * correction, but I don't want to insert extra constraints here. */ - U32 const maxDictSize = ZSTD_CURRENT_MAX - 1; - /* We must have cleared our windows when our source is this large. */ - assert(ZSTD_window_isEmpty(ms->window)); - if (loadLdmDict) - assert(ZSTD_window_isEmpty(ls->window)); + U32 maxDictSize = ZSTD_CURRENT_MAX - ZSTD_WINDOW_START_INDEX; + + int const CDictTaggedIndices = ZSTD_CDictIndicesAreTagged(¶ms->cParams); + if (CDictTaggedIndices && tfp == ZSTD_tfp_forCDict) { + /* Some dictionary matchfinders in zstd use "short cache", + * which treats the lower ZSTD_SHORT_CACHE_TAG_BITS of each + * CDict hashtable entry as a tag rather than as part of an index. + * When short cache is used, we need to truncate the dictionary + * so that its indices don't overlap with the tag. */ + U32 const shortCacheMaxDictSize = (1u << (32 - ZSTD_SHORT_CACHE_TAG_BITS)) - ZSTD_WINDOW_START_INDEX; + maxDictSize = MIN(maxDictSize, shortCacheMaxDictSize); + assert(!loadLdmDict); + } + /* If the dictionary is too large, only load the suffix of the dictionary. */ if (srcSize > maxDictSize) { ip = iend - maxDictSize; @@ -4138,35 +4856,59 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, } } - DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder); + if (srcSize > ZSTD_CHUNKSIZE_MAX) { + /* We must have cleared our windows when our source is this large. */ + assert(ZSTD_window_isEmpty(ms->window)); + if (loadLdmDict) assert(ZSTD_window_isEmpty(ls->window)); + } ZSTD_window_update(&ms->window, src, srcSize, /* forceNonContiguous */ 0); - ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); - ms->forceNonContiguous = params->deterministicRefPrefix; - if (loadLdmDict) { + DEBUGLOG(4, "ZSTD_loadDictionaryContent: useRowMatchFinder=%d", (int)params->useRowMatchFinder); + + if (loadLdmDict) { /* Load the entire dict into LDM matchfinders. */ + DEBUGLOG(4, "ZSTD_loadDictionaryContent: Trigger loadLdmDict"); ZSTD_window_update(&ls->window, src, srcSize, /* forceNonContiguous */ 0); ls->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ls->window.base); + ZSTD_ldm_fillHashTable(ls, ip, iend, ¶ms->ldmParams); + DEBUGLOG(4, "ZSTD_loadDictionaryContent: ZSTD_ldm_fillHashTable completes"); + } + + /* If the dict is larger than we can reasonably index in our tables, only load the suffix. */ + { U32 maxDictSize = 1U << MIN(MAX(params->cParams.hashLog + 3, params->cParams.chainLog + 1), 31); + if (srcSize > maxDictSize) { + ip = iend - maxDictSize; + src = ip; + srcSize = maxDictSize; + } } + ms->nextToUpdate = (U32)(ip - ms->window.base); + ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); + ms->forceNonContiguous = params->deterministicRefPrefix; + if (srcSize <= HASH_READ_SIZE) return 0; ZSTD_overflowCorrectIfNeeded(ms, ws, params, ip, iend); - if (loadLdmDict) - ZSTD_ldm_fillHashTable(ls, ip, iend, ¶ms->ldmParams); - switch(params->cParams.strategy) { case ZSTD_fast: - ZSTD_fillHashTable(ms, iend, dtlm); + ZSTD_fillHashTable(ms, iend, dtlm, tfp); break; case ZSTD_dfast: - ZSTD_fillDoubleHashTable(ms, iend, dtlm); +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + ZSTD_fillDoubleHashTable(ms, iend, dtlm, tfp); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_greedy: case ZSTD_lazy: case ZSTD_lazy2: +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) assert(srcSize >= HASH_READ_SIZE); if (ms->dedicatedDictSearch) { assert(ms->chainTable != NULL); @@ -4174,7 +4916,7 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, } else { assert(params->useRowMatchFinder != ZSTD_ps_auto); if (params->useRowMatchFinder == ZSTD_ps_enable) { - size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog) * sizeof(U16); + size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog); ZSTD_memset(ms->tagTable, 0, tagTableSize); ZSTD_row_update(ms, iend-HASH_READ_SIZE); DEBUGLOG(4, "Using row-based hash table for lazy dict"); @@ -4183,14 +4925,24 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, DEBUGLOG(4, "Using chain-based hash table for lazy dict"); } } +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_btlazy2: /* we want the dictionary table fully sorted */ case ZSTD_btopt: case ZSTD_btultra: case ZSTD_btultra2: +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) assert(srcSize >= HASH_READ_SIZE); + DEBUGLOG(4, "Fill %u bytes into the Binary Tree", (unsigned)srcSize); ZSTD_updateTree(ms, iend-HASH_READ_SIZE, iend); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; default: @@ -4233,20 +4985,19 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, { unsigned maxSymbolValue = 255; unsigned hasZeroWeights = 1; size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, - dictEnd-dictPtr, &hasZeroWeights); + (size_t)(dictEnd-dictPtr), &hasZeroWeights); /* We only set the loaded table as valid if it contains all non-zero * weights. Otherwise, we set it to check */ - if (!hasZeroWeights) + if (!hasZeroWeights && maxSymbolValue == 255) bs->entropy.huf.repeatMode = HUF_repeat_valid; RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted, ""); - RETURN_ERROR_IF(maxSymbolValue < 255, dictionary_corrupted, ""); dictPtr += hufHeaderSize; } { unsigned offcodeLog; - size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, dictEnd-dictPtr); + size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, (size_t)(dictEnd-dictPtr)); RETURN_ERROR_IF(FSE_isError(offcodeHeaderSize), dictionary_corrupted, ""); RETURN_ERROR_IF(offcodeLog > OffFSELog, dictionary_corrupted, ""); /* fill all offset symbols to avoid garbage at end of table */ @@ -4261,7 +5012,7 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, { short matchlengthNCount[MaxML+1]; unsigned matchlengthMaxValue = MaxML, matchlengthLog; - size_t const matchlengthHeaderSize = FSE_readNCount(matchlengthNCount, &matchlengthMaxValue, &matchlengthLog, dictPtr, dictEnd-dictPtr); + size_t const matchlengthHeaderSize = FSE_readNCount(matchlengthNCount, &matchlengthMaxValue, &matchlengthLog, dictPtr, (size_t)(dictEnd-dictPtr)); RETURN_ERROR_IF(FSE_isError(matchlengthHeaderSize), dictionary_corrupted, ""); RETURN_ERROR_IF(matchlengthLog > MLFSELog, dictionary_corrupted, ""); RETURN_ERROR_IF(FSE_isError(FSE_buildCTable_wksp( @@ -4275,7 +5026,7 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, { short litlengthNCount[MaxLL+1]; unsigned litlengthMaxValue = MaxLL, litlengthLog; - size_t const litlengthHeaderSize = FSE_readNCount(litlengthNCount, &litlengthMaxValue, &litlengthLog, dictPtr, dictEnd-dictPtr); + size_t const litlengthHeaderSize = FSE_readNCount(litlengthNCount, &litlengthMaxValue, &litlengthLog, dictPtr, (size_t)(dictEnd-dictPtr)); RETURN_ERROR_IF(FSE_isError(litlengthHeaderSize), dictionary_corrupted, ""); RETURN_ERROR_IF(litlengthLog > LLFSELog, dictionary_corrupted, ""); RETURN_ERROR_IF(FSE_isError(FSE_buildCTable_wksp( @@ -4309,7 +5060,7 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, RETURN_ERROR_IF(bs->rep[u] > dictContentSize, dictionary_corrupted, ""); } } } - return dictPtr - (const BYTE*)dict; + return (size_t)(dictPtr - (const BYTE*)dict); } /* Dictionary format : @@ -4322,11 +5073,12 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, * dictSize supposed >= 8 */ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, ZSTD_cwksp* ws, ZSTD_CCtx_params const* params, const void* dict, size_t dictSize, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp, void* workspace) { const BYTE* dictPtr = (const BYTE*)dict; @@ -4345,7 +5097,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, { size_t const dictContentSize = (size_t)(dictEnd - dictPtr); FORWARD_IF_ERROR(ZSTD_loadDictionaryContent( - ms, NULL, ws, params, dictPtr, dictContentSize, dtlm), ""); + ms, NULL, ws, params, dictPtr, dictContentSize, dtlm, tfp), ""); } return dictID; } @@ -4354,13 +5106,14 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, * @return : dictID, or an error code */ static size_t ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, ldmState_t* ls, ZSTD_cwksp* ws, const ZSTD_CCtx_params* params, const void* dict, size_t dictSize, ZSTD_dictContentType_e dictContentType, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp, void* workspace) { DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize); @@ -4373,13 +5126,13 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, /* dict restricted modes */ if (dictContentType == ZSTD_dct_rawContent) - return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm); + return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm, tfp); if (MEM_readLE32(dict) != ZSTD_MAGIC_DICTIONARY) { if (dictContentType == ZSTD_dct_auto) { DEBUGLOG(4, "raw content dictionary detected"); return ZSTD_loadDictionaryContent( - ms, ls, ws, params, dict, dictSize, dtlm); + ms, ls, ws, params, dict, dictSize, dtlm, tfp); } RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong, ""); assert(0); /* impossible */ @@ -4387,13 +5140,14 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, /* dict as full zstd dictionary */ return ZSTD_loadZstdDictionary( - bs, ms, ws, params, dict, dictSize, dtlm, workspace); + bs, ms, ws, params, dict, dictSize, dtlm, tfp, workspace); } #define ZSTD_USE_CDICT_PARAMS_SRCSIZE_CUTOFF (128 KB) #define ZSTD_USE_CDICT_PARAMS_DICTSIZE_MULTIPLIER (6ULL) /*! ZSTD_compressBegin_internal() : + * Assumption : either @dict OR @cdict (or none) is non-NULL, never both * @return : 0, or an error code */ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, @@ -4426,11 +5180,11 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, cctx->blockState.prevCBlock, &cctx->blockState.matchState, &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, cdict->dictContent, cdict->dictContentSize, cdict->dictContentType, dtlm, - cctx->entropyWorkspace) + ZSTD_tfp_forCCtx, cctx->tmpWorkspace) : ZSTD_compress_insertDictionary( cctx->blockState.prevCBlock, &cctx->blockState.matchState, &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, dict, dictSize, - dictContentType, dtlm, cctx->entropyWorkspace); + dictContentType, dtlm, ZSTD_tfp_forCCtx, cctx->tmpWorkspace); FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed"); assert(dictID <= UINT_MAX); cctx->dictID = (U32)dictID; @@ -4471,11 +5225,11 @@ size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, &cctxParams, pledgedSrcSize); } -size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) +static size_t +ZSTD_compressBegin_usingDict_deprecated(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) { ZSTD_CCtx_params cctxParams; - { - ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_noAttachDict); + { ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_noAttachDict); ZSTD_CCtxParams_init_internal(&cctxParams, ¶ms, (compressionLevel == 0) ? ZSTD_CLEVEL_DEFAULT : compressionLevel); } DEBUGLOG(4, "ZSTD_compressBegin_usingDict (dictSize=%u)", (unsigned)dictSize); @@ -4483,9 +5237,15 @@ size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t di &cctxParams, ZSTD_CONTENTSIZE_UNKNOWN, ZSTDb_not_buffered); } +size_t +ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) +{ + return ZSTD_compressBegin_usingDict_deprecated(cctx, dict, dictSize, compressionLevel); +} + size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel) { - return ZSTD_compressBegin_usingDict(cctx, NULL, 0, compressionLevel); + return ZSTD_compressBegin_usingDict_deprecated(cctx, NULL, 0, compressionLevel); } @@ -4496,14 +5256,13 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) { BYTE* const ostart = (BYTE*)dst; BYTE* op = ostart; - size_t fhSize = 0; DEBUGLOG(4, "ZSTD_writeEpilogue"); RETURN_ERROR_IF(cctx->stage == ZSTDcs_created, stage_wrong, "init missing"); /* special case : empty frame */ if (cctx->stage == ZSTDcs_init) { - fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); + size_t fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); FORWARD_IF_ERROR(fhSize, "ZSTD_writeFrameHeader failed"); dstCapacity -= fhSize; op += fhSize; @@ -4513,8 +5272,9 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) if (cctx->stage != ZSTDcs_ending) { /* write one last empty block, make it the "last" block */ U32 const cBlockHeader24 = 1 /* last block */ + (((U32)bt_raw)<<1) + 0; - RETURN_ERROR_IF(dstCapacity<4, dstSize_tooSmall, "no room for epilogue"); - MEM_writeLE32(op, cBlockHeader24); + ZSTD_STATIC_ASSERT(ZSTD_BLOCKHEADERSIZE == 3); + RETURN_ERROR_IF(dstCapacity<3, dstSize_tooSmall, "no room for epilogue"); + MEM_writeLE24(op, cBlockHeader24); op += ZSTD_blockHeaderSize; dstCapacity -= ZSTD_blockHeaderSize; } @@ -4528,7 +5288,7 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) } cctx->stage = ZSTDcs_created; /* return to "created but no init" status */ - return op-ostart; + return (size_t)(op-ostart); } void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize) @@ -4537,9 +5297,9 @@ void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize) (void)extraCSize; } -size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { size_t endResult; size_t const cSize = ZSTD_compressContinue_internal(cctx, @@ -4563,6 +5323,14 @@ size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, return cSize + endResult; } +/* NOTE: Must just wrap ZSTD_compressEnd_public() */ +size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); +} + size_t ZSTD_compress_advanced (ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, @@ -4591,7 +5359,7 @@ size_t ZSTD_compress_advanced_internal( FORWARD_IF_ERROR( ZSTD_compressBegin_internal(cctx, dict, dictSize, ZSTD_dct_auto, ZSTD_dtlm_fast, NULL, params, srcSize, ZSTDb_not_buffered) , ""); - return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); } size_t ZSTD_compress_usingDict(ZSTD_CCtx* cctx, @@ -4709,7 +5477,7 @@ static size_t ZSTD_initCDict_internal( { size_t const dictID = ZSTD_compress_insertDictionary( &cdict->cBlockState, &cdict->matchState, NULL, &cdict->workspace, ¶ms, cdict->dictContent, cdict->dictContentSize, - dictContentType, ZSTD_dtlm_full, cdict->entropyWorkspace); + dictContentType, ZSTD_dtlm_full, ZSTD_tfp_forCDict, cdict->entropyWorkspace); FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed"); assert(dictID <= (size_t)(U32)-1); cdict->dictID = (U32)dictID; @@ -4719,14 +5487,16 @@ static size_t ZSTD_initCDict_internal( return 0; } -static ZSTD_CDict* ZSTD_createCDict_advanced_internal(size_t dictSize, - ZSTD_dictLoadMethod_e dictLoadMethod, - ZSTD_compressionParameters cParams, - ZSTD_paramSwitch_e useRowMatchFinder, - U32 enableDedicatedDictSearch, - ZSTD_customMem customMem) +static ZSTD_CDict* +ZSTD_createCDict_advanced_internal(size_t dictSize, + ZSTD_dictLoadMethod_e dictLoadMethod, + ZSTD_compressionParameters cParams, + ZSTD_ParamSwitch_e useRowMatchFinder, + int enableDedicatedDictSearch, + ZSTD_customMem customMem) { if ((!customMem.customAlloc) ^ (!customMem.customFree)) return NULL; + DEBUGLOG(3, "ZSTD_createCDict_advanced_internal (dictSize=%u)", (unsigned)dictSize); { size_t const workspaceSize = ZSTD_cwksp_alloc_size(sizeof(ZSTD_CDict)) + @@ -4763,6 +5533,7 @@ ZSTD_CDict* ZSTD_createCDict_advanced(const void* dictBuffer, size_t dictSize, { ZSTD_CCtx_params cctxParams; ZSTD_memset(&cctxParams, 0, sizeof(cctxParams)); + DEBUGLOG(3, "ZSTD_createCDict_advanced, dictSize=%u, mode=%u", (unsigned)dictSize, (unsigned)dictContentType); ZSTD_CCtxParams_init(&cctxParams, 0); cctxParams.cParams = cParams; cctxParams.customMem = customMem; @@ -4783,7 +5554,7 @@ ZSTD_CDict* ZSTD_createCDict_advanced2( ZSTD_compressionParameters cParams; ZSTD_CDict* cdict; - DEBUGLOG(3, "ZSTD_createCDict_advanced2, mode %u", (unsigned)dictContentType); + DEBUGLOG(3, "ZSTD_createCDict_advanced2, dictSize=%u, mode=%u", (unsigned)dictSize, (unsigned)dictContentType); if (!customMem.customAlloc ^ !customMem.customFree) return NULL; if (cctxParams.enableDedicatedDictSearch) { @@ -4802,7 +5573,7 @@ ZSTD_CDict* ZSTD_createCDict_advanced2( &cctxParams, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_createCDict); } - DEBUGLOG(3, "ZSTD_createCDict_advanced2: DDS: %u", cctxParams.enableDedicatedDictSearch); + DEBUGLOG(3, "ZSTD_createCDict_advanced2: DedicatedDictSearch=%u", cctxParams.enableDedicatedDictSearch); cctxParams.cParams = cParams; cctxParams.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams.useRowMatchFinder, &cParams); @@ -4810,10 +5581,8 @@ ZSTD_CDict* ZSTD_createCDict_advanced2( dictLoadMethod, cctxParams.cParams, cctxParams.useRowMatchFinder, cctxParams.enableDedicatedDictSearch, customMem); - if (!cdict) - return NULL; - if (ZSTD_isError( ZSTD_initCDict_internal(cdict, + if (!cdict || ZSTD_isError( ZSTD_initCDict_internal(cdict, dict, dictSize, dictLoadMethod, dictContentType, cctxParams) )) { @@ -4867,7 +5636,7 @@ size_t ZSTD_freeCDict(ZSTD_CDict* cdict) * workspaceSize: Use ZSTD_estimateCDictSize() * to determine how large workspace must be. * cParams : use ZSTD_getCParams() to transform a compression level - * into its relevants cParams. + * into its relevant cParams. * @return : pointer to ZSTD_CDict*, or NULL if error (size too small) * Note : there is no corresponding "free" function. * Since workspace was allocated externally, it must be freed externally. @@ -4879,7 +5648,7 @@ const ZSTD_CDict* ZSTD_initStaticCDict( ZSTD_dictContentType_e dictContentType, ZSTD_compressionParameters cParams) { - ZSTD_paramSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(ZSTD_ps_auto, &cParams); + ZSTD_ParamSwitch_e const useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(ZSTD_ps_auto, &cParams); /* enableDedicatedDictSearch == 1 ensures matchstate is not too small in case this CDict will be used for DDS + row hash */ size_t const matchStateSize = ZSTD_sizeof_matchState(&cParams, useRowMatchFinder, /* enableDedicatedDictSearch */ 1, /* forCCtx */ 0); size_t const neededSize = ZSTD_cwksp_alloc_size(sizeof(ZSTD_CDict)) @@ -4890,6 +5659,7 @@ const ZSTD_CDict* ZSTD_initStaticCDict( ZSTD_CDict* cdict; ZSTD_CCtx_params params; + DEBUGLOG(4, "ZSTD_initStaticCDict (dictSize==%u)", (unsigned)dictSize); if ((size_t)workspace & 7) return NULL; /* 8-aligned */ { @@ -4900,14 +5670,13 @@ const ZSTD_CDict* ZSTD_initStaticCDict( ZSTD_cwksp_move(&cdict->workspace, &ws); } - DEBUGLOG(4, "(workspaceSize < neededSize) : (%u < %u) => %u", - (unsigned)workspaceSize, (unsigned)neededSize, (unsigned)(workspaceSize < neededSize)); if (workspaceSize < neededSize) return NULL; ZSTD_CCtxParams_init(¶ms, 0); params.cParams = cParams; params.useRowMatchFinder = useRowMatchFinder; cdict->useRowMatchFinder = useRowMatchFinder; + cdict->compressionLevel = ZSTD_NO_CLEVEL; if (ZSTD_isError( ZSTD_initCDict_internal(cdict, dict, dictSize, @@ -4987,12 +5756,17 @@ size_t ZSTD_compressBegin_usingCDict_advanced( /* ZSTD_compressBegin_usingCDict() : * cdict must be != NULL */ -size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) +size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) { ZSTD_frameParameters const fParams = { 0 /*content*/, 0 /*checksum*/, 0 /*noDictID*/ }; return ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, ZSTD_CONTENTSIZE_UNKNOWN); } +size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) +{ + return ZSTD_compressBegin_usingCDict_deprecated(cctx, cdict); +} + /*! ZSTD_compress_usingCDict_internal(): * Implementation of various ZSTD_compress_usingCDict* functions. */ @@ -5002,7 +5776,7 @@ static size_t ZSTD_compress_usingCDict_internal(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict, ZSTD_frameParameters fParams) { FORWARD_IF_ERROR(ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, srcSize), ""); /* will check if cdict != NULL */ - return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); } /*! ZSTD_compress_usingCDict_advanced(): @@ -5068,7 +5842,7 @@ size_t ZSTD_CStreamOutSize(void) return ZSTD_compressBound(ZSTD_BLOCKSIZE_MAX) + ZSTD_blockHeaderSize + 4 /* 32-bits hash */ ; } -static ZSTD_cParamMode_e ZSTD_getCParamMode(ZSTD_CDict const* cdict, ZSTD_CCtx_params const* params, U64 pledgedSrcSize) +static ZSTD_CParamMode_e ZSTD_getCParamMode(ZSTD_CDict const* cdict, ZSTD_CCtx_params const* params, U64 pledgedSrcSize) { if (cdict != NULL && ZSTD_shouldAttachDict(cdict, params, pledgedSrcSize)) return ZSTD_cpm_attachDict; @@ -5199,30 +5973,41 @@ size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel) static size_t ZSTD_nextInputSizeHint(const ZSTD_CCtx* cctx) { - size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; - if (hintInSize==0) hintInSize = cctx->blockSize; - return hintInSize; + if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { + return cctx->blockSizeMax - cctx->stableIn_notConsumed; + } + assert(cctx->appliedParams.inBufferMode == ZSTD_bm_buffered); + { size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; + if (hintInSize==0) hintInSize = cctx->blockSizeMax; + return hintInSize; + } } /* ZSTD_compressStream_generic(): * internal function for all *compressStream*() variants - * non-static, because can be called from zstdmt_compress.c - * @return : hint size for next input */ + * @return : hint size for next input to complete ongoing block */ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective const flushMode) { - const char* const istart = (const char*)input->src; - const char* const iend = input->size != 0 ? istart + input->size : istart; - const char* ip = input->pos != 0 ? istart + input->pos : istart; - char* const ostart = (char*)output->dst; - char* const oend = output->size != 0 ? ostart + output->size : ostart; - char* op = output->pos != 0 ? ostart + output->pos : ostart; + const char* const istart = (assert(input != NULL), (const char*)input->src); + const char* const iend = (istart != NULL) ? istart + input->size : istart; + const char* ip = (istart != NULL) ? istart + input->pos : istart; + char* const ostart = (assert(output != NULL), (char*)output->dst); + char* const oend = (ostart != NULL) ? ostart + output->size : ostart; + char* op = (ostart != NULL) ? ostart + output->pos : ostart; U32 someMoreWork = 1; /* check expectations */ - DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%u", (unsigned)flushMode); + DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%i, srcSize = %zu", (int)flushMode, input->size - input->pos); + assert(zcs != NULL); + if (zcs->appliedParams.inBufferMode == ZSTD_bm_stable) { + assert(input->pos >= zcs->stableIn_notConsumed); + input->pos -= zcs->stableIn_notConsumed; + if (ip) ip -= zcs->stableIn_notConsumed; + zcs->stableIn_notConsumed = 0; + } if (zcs->appliedParams.inBufferMode == ZSTD_bm_buffered) { assert(zcs->inBuff != NULL); assert(zcs->inBuffSize > 0); @@ -5231,8 +6016,10 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, assert(zcs->outBuff != NULL); assert(zcs->outBuffSize > 0); } - assert(output->pos <= output->size); + if (input->src == NULL) assert(input->size == 0); assert(input->pos <= input->size); + if (output->dst == NULL) assert(output->size == 0); + assert(output->pos <= output->size); assert((U32)flushMode <= (U32)ZSTD_e_end); while (someMoreWork) { @@ -5243,12 +6030,13 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, case zcss_load: if ( (flushMode == ZSTD_e_end) - && ( (size_t)(oend-op) >= ZSTD_compressBound(iend-ip) /* Enough output space */ + && ( (size_t)(oend-op) >= ZSTD_compressBound((size_t)(iend-ip)) /* Enough output space */ || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) /* OR we are allowed to return dstSizeTooSmall */ && (zcs->inBuffPos == 0) ) { /* shortcut to compression pass directly into output buffer */ - size_t const cSize = ZSTD_compressEnd(zcs, - op, oend-op, ip, iend-ip); + size_t const cSize = ZSTD_compressEnd_public(zcs, + op, (size_t)(oend-op), + ip, (size_t)(iend-ip)); DEBUGLOG(4, "ZSTD_compressEnd : cSize=%u", (unsigned)cSize); FORWARD_IF_ERROR(cSize, "ZSTD_compressEnd failed"); ip = iend; @@ -5262,10 +6050,9 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, size_t const toLoad = zcs->inBuffTarget - zcs->inBuffPos; size_t const loaded = ZSTD_limitCopy( zcs->inBuff + zcs->inBuffPos, toLoad, - ip, iend-ip); + ip, (size_t)(iend-ip)); zcs->inBuffPos += loaded; - if (loaded != 0) - ip += loaded; + if (ip) ip += loaded; if ( (flushMode == ZSTD_e_continue) && (zcs->inBuffPos < zcs->inBuffTarget) ) { /* not enough input to fill full block : stop here */ @@ -5276,16 +6063,29 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, /* empty */ someMoreWork = 0; break; } + } else { + assert(zcs->appliedParams.inBufferMode == ZSTD_bm_stable); + if ( (flushMode == ZSTD_e_continue) + && ( (size_t)(iend - ip) < zcs->blockSizeMax) ) { + /* can't compress a full block : stop here */ + zcs->stableIn_notConsumed = (size_t)(iend - ip); + ip = iend; /* pretend to have consumed input */ + someMoreWork = 0; break; + } + if ( (flushMode == ZSTD_e_flush) + && (ip == iend) ) { + /* empty */ + someMoreWork = 0; break; + } } /* compress current block (note : this stage cannot be stopped in the middle) */ DEBUGLOG(5, "stream compression stage (flushMode==%u)", flushMode); { int const inputBuffered = (zcs->appliedParams.inBufferMode == ZSTD_bm_buffered); void* cDst; size_t cSize; - size_t oSize = oend-op; - size_t const iSize = inputBuffered - ? zcs->inBuffPos - zcs->inToCompress - : MIN((size_t)(iend - ip), zcs->blockSize); + size_t oSize = (size_t)(oend-op); + size_t const iSize = inputBuffered ? zcs->inBuffPos - zcs->inToCompress + : MIN((size_t)(iend - ip), zcs->blockSizeMax); if (oSize >= ZSTD_compressBound(iSize) || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) cDst = op; /* compress into output buffer, to skip flush stage */ else @@ -5293,34 +6093,31 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, if (inputBuffered) { unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip==iend); cSize = lastBlock ? - ZSTD_compressEnd(zcs, cDst, oSize, + ZSTD_compressEnd_public(zcs, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize) : - ZSTD_compressContinue(zcs, cDst, oSize, + ZSTD_compressContinue_public(zcs, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize); FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); zcs->frameEnded = lastBlock; /* prepare next block */ - zcs->inBuffTarget = zcs->inBuffPos + zcs->blockSize; + zcs->inBuffTarget = zcs->inBuffPos + zcs->blockSizeMax; if (zcs->inBuffTarget > zcs->inBuffSize) - zcs->inBuffPos = 0, zcs->inBuffTarget = zcs->blockSize; + zcs->inBuffPos = 0, zcs->inBuffTarget = zcs->blockSizeMax; DEBUGLOG(5, "inBuffTarget:%u / inBuffSize:%u", (unsigned)zcs->inBuffTarget, (unsigned)zcs->inBuffSize); if (!lastBlock) assert(zcs->inBuffTarget <= zcs->inBuffSize); zcs->inToCompress = zcs->inBuffPos; - } else { - unsigned const lastBlock = (ip + iSize == iend); - assert(flushMode == ZSTD_e_end /* Already validated */); + } else { /* !inputBuffered, hence ZSTD_bm_stable */ + unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip + iSize == iend); cSize = lastBlock ? - ZSTD_compressEnd(zcs, cDst, oSize, ip, iSize) : - ZSTD_compressContinue(zcs, cDst, oSize, ip, iSize); + ZSTD_compressEnd_public(zcs, cDst, oSize, ip, iSize) : + ZSTD_compressContinue_public(zcs, cDst, oSize, ip, iSize); /* Consume the input prior to error checking to mirror buffered mode. */ - if (iSize > 0) - ip += iSize; + if (ip) ip += iSize; FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); zcs->frameEnded = lastBlock; - if (lastBlock) - assert(ip == iend); + if (lastBlock) assert(ip == iend); } if (cDst == op) { /* no need to flush */ op += cSize; @@ -5369,8 +6166,8 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, } } - input->pos = ip - istart; - output->pos = op - ostart; + input->pos = (size_t)(ip - istart); + output->pos = (size_t)(op - ostart); if (zcs->frameEnded) return 0; return ZSTD_nextInputSizeHint(zcs); } @@ -5390,8 +6187,10 @@ size_t ZSTD_compressStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuf /* After a compression call set the expected input/output buffer. * This is validated at the start of the next compression call. */ -static void ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, ZSTD_outBuffer const* output, ZSTD_inBuffer const* input) +static void +ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, const ZSTD_outBuffer* output, const ZSTD_inBuffer* input) { + DEBUGLOG(5, "ZSTD_setBufferExpectations (for advanced stable in/out modes)"); if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { cctx->expectedInBuffer = *input; } @@ -5410,22 +6209,27 @@ static size_t ZSTD_checkBufferStability(ZSTD_CCtx const* cctx, { if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { ZSTD_inBuffer const expect = cctx->expectedInBuffer; - if (expect.src != input->src || expect.pos != input->pos || expect.size != input->size) - RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer enabled but input differs!"); - if (endOp != ZSTD_e_end) - RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer can only be used with ZSTD_e_end!"); + if (expect.src != input->src || expect.pos != input->pos) + RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableInBuffer enabled but input differs!"); } + (void)endOp; if (cctx->appliedParams.outBufferMode == ZSTD_bm_stable) { size_t const outBufferSize = output->size - output->pos; if (cctx->expectedOutBufferSize != outBufferSize) - RETURN_ERROR(dstBuffer_wrong, "ZSTD_c_stableOutBuffer enabled but output size differs!"); + RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableOutBuffer enabled but output size differs!"); } return 0; } +/* + * If @endOp == ZSTD_e_end, @inSize becomes pledgedSrcSize. + * Otherwise, it's ignored. + * @return: 0 on success, or a ZSTD_error code otherwise. + */ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, ZSTD_EndDirective endOp, - size_t inSize) { + size_t inSize) +{ ZSTD_CCtx_params params = cctx->requestedParams; ZSTD_prefixDict const prefixDict = cctx->prefixDict; FORWARD_IF_ERROR( ZSTD_initLocalDict(cctx) , ""); /* Init the local dict if present. */ @@ -5438,21 +6242,24 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, */ params.compressionLevel = cctx->cdict->compressionLevel; } - DEBUGLOG(4, "ZSTD_compressStream2 : transparent init stage"); - if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-fix pledgedSrcSize */ - { - size_t const dictSize = prefixDict.dict + DEBUGLOG(4, "ZSTD_CCtx_init_compressStream2 : transparent init stage"); + if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-determine pledgedSrcSize */ + + { size_t const dictSize = prefixDict.dict ? prefixDict.dictSize : (cctx->cdict ? cctx->cdict->dictContentSize : 0); - ZSTD_cParamMode_e const mode = ZSTD_getCParamMode(cctx->cdict, ¶ms, cctx->pledgedSrcSizePlusOne - 1); + ZSTD_CParamMode_e const mode = ZSTD_getCParamMode(cctx->cdict, ¶ms, cctx->pledgedSrcSizePlusOne - 1); params.cParams = ZSTD_getCParamsFromCCtxParams( ¶ms, cctx->pledgedSrcSizePlusOne-1, dictSize, mode); } - params.useBlockSplitter = ZSTD_resolveBlockSplitterMode(params.useBlockSplitter, ¶ms.cParams); + params.postBlockSplitter = ZSTD_resolveBlockSplitterMode(params.postBlockSplitter, ¶ms.cParams); params.ldmParams.enableLdm = ZSTD_resolveEnableLdm(params.ldmParams.enableLdm, ¶ms.cParams); params.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params.useRowMatchFinder, ¶ms.cParams); + params.validateSequences = ZSTD_resolveExternalSequenceValidation(params.validateSequences); + params.maxBlockSize = ZSTD_resolveMaxBlockSize(params.maxBlockSize); + params.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(params.searchForExternalRepcodes, params.compressionLevel); { U64 const pledgedSrcSize = cctx->pledgedSrcSizePlusOne - 1; assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); @@ -5468,7 +6275,7 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, /* for small input: avoid automatic flush on reaching end of block, since * it would require to add a 3-bytes null block to end frame */ - cctx->inBuffTarget = cctx->blockSize + (cctx->blockSize == pledgedSrcSize); + cctx->inBuffTarget = cctx->blockSizeMax + (cctx->blockSizeMax == pledgedSrcSize); } else { cctx->inBuffTarget = 0; } @@ -5479,6 +6286,8 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, return 0; } +/* @return provides a minimum amount of data remaining to be flushed from internal buffers + */ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, @@ -5493,8 +6302,27 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, /* transparent initialization stage */ if (cctx->streamStage == zcss_init) { - FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, input->size), "CompressStream2 initialization failed"); - ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ + size_t const inputSize = input->size - input->pos; /* no obligation to start from pos==0 */ + size_t const totalInputSize = inputSize + cctx->stableIn_notConsumed; + if ( (cctx->requestedParams.inBufferMode == ZSTD_bm_stable) /* input is presumed stable, across invocations */ + && (endOp == ZSTD_e_continue) /* no flush requested, more input to come */ + && (totalInputSize < ZSTD_BLOCKSIZE_MAX) ) { /* not even reached one block yet */ + if (cctx->stableIn_notConsumed) { /* not the first time */ + /* check stable source guarantees */ + RETURN_ERROR_IF(input->src != cctx->expectedInBuffer.src, stabilityCondition_notRespected, "stableInBuffer condition not respected: wrong src pointer"); + RETURN_ERROR_IF(input->pos != cctx->expectedInBuffer.size, stabilityCondition_notRespected, "stableInBuffer condition not respected: externally modified pos"); + } + /* pretend input was consumed, to give a sense forward progress */ + input->pos = input->size; + /* save stable inBuffer, for later control, and flush/end */ + cctx->expectedInBuffer = *input; + /* but actually input wasn't consumed, so keep track of position from where compression shall resume */ + cctx->stableIn_notConsumed += inputSize; + /* don't initialize yet, wait for the first block of flush() order, for better parameters adaptation */ + return ZSTD_FRAMEHEADERSIZE_MIN(cctx->requestedParams.format); /* at least some header to produce */ + } + FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, totalInputSize), "compressStream2 initialization failed"); + ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ } /* end of transparent initialization stage */ @@ -5512,13 +6340,20 @@ size_t ZSTD_compressStream2_simpleArgs ( const void* src, size_t srcSize, size_t* srcPos, ZSTD_EndDirective endOp) { - ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; - ZSTD_inBuffer input = { src, srcSize, *srcPos }; + ZSTD_outBuffer output; + ZSTD_inBuffer input; + output.dst = dst; + output.size = dstCapacity; + output.pos = *dstPos; + input.src = src; + input.size = srcSize; + input.pos = *srcPos; /* ZSTD_compressStream2() will check validity of dstPos and srcPos */ - size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); - *dstPos = output.pos; - *srcPos = input.pos; - return cErr; + { size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); + *dstPos = output.pos; + *srcPos = input.pos; + return cErr; + } } size_t ZSTD_compress2(ZSTD_CCtx* cctx, @@ -5541,6 +6376,7 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, /* Reset to the original values. */ cctx->requestedParams.inBufferMode = originalInBufferMode; cctx->requestedParams.outBufferMode = originalOutBufferMode; + FORWARD_IF_ERROR(result, "ZSTD_compressStream2_simpleArgs failed"); if (result != 0) { /* compression not completed, due to lack of output space */ assert(oPos == dstCapacity); @@ -5551,64 +6387,67 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, } } -typedef struct { - U32 idx; /* Index in array of ZSTD_Sequence */ - U32 posInSequence; /* Position within sequence at idx */ - size_t posInSrc; /* Number of bytes given by sequences provided so far */ -} ZSTD_sequencePosition; - /* ZSTD_validateSequence() : - * @offCode : is presumed to follow format required by ZSTD_storeSeq() + * @offBase : must use the format required by ZSTD_storeSeq() * @returns a ZSTD error code if sequence is not valid */ static size_t -ZSTD_validateSequence(U32 offCode, U32 matchLength, - size_t posInSrc, U32 windowLog, size_t dictSize) +ZSTD_validateSequence(U32 offBase, U32 matchLength, U32 minMatch, + size_t posInSrc, U32 windowLog, size_t dictSize, int useSequenceProducer) { - U32 const windowSize = 1 << windowLog; + U32 const windowSize = 1u << windowLog; /* posInSrc represents the amount of data the decoder would decode up to this point. * As long as the amount of data decoded is less than or equal to window size, offsets may be * larger than the total length of output decoded in order to reference the dict, even larger than * window size. After output surpasses windowSize, we're limited to windowSize offsets again. */ size_t const offsetBound = posInSrc > windowSize ? (size_t)windowSize : posInSrc + (size_t)dictSize; - RETURN_ERROR_IF(offCode > STORE_OFFSET(offsetBound), corruption_detected, "Offset too large!"); - RETURN_ERROR_IF(matchLength < MINMATCH, corruption_detected, "Matchlength too small"); + size_t const matchLenLowerBound = (minMatch == 3 || useSequenceProducer) ? 3 : 4; + RETURN_ERROR_IF(offBase > OFFSET_TO_OFFBASE(offsetBound), externalSequences_invalid, "Offset too large!"); + /* Validate maxNbSeq is large enough for the given matchLength and minMatch */ + RETURN_ERROR_IF(matchLength < matchLenLowerBound, externalSequences_invalid, "Matchlength too small for the minMatch"); return 0; } /* Returns an offset code, given a sequence's raw offset, the ongoing repcode array, and whether litLength == 0 */ -static U32 ZSTD_finalizeOffCode(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) +static U32 ZSTD_finalizeOffBase(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) { - U32 offCode = STORE_OFFSET(rawOffset); + U32 offBase = OFFSET_TO_OFFBASE(rawOffset); if (!ll0 && rawOffset == rep[0]) { - offCode = STORE_REPCODE_1; + offBase = REPCODE1_TO_OFFBASE; } else if (rawOffset == rep[1]) { - offCode = STORE_REPCODE(2 - ll0); + offBase = REPCODE_TO_OFFBASE(2 - ll0); } else if (rawOffset == rep[2]) { - offCode = STORE_REPCODE(3 - ll0); + offBase = REPCODE_TO_OFFBASE(3 - ll0); } else if (ll0 && rawOffset == rep[0] - 1) { - offCode = STORE_REPCODE_3; + offBase = REPCODE3_TO_OFFBASE; } - return offCode; + return offBase; } -/* Returns 0 on success, and a ZSTD_error otherwise. This function scans through an array of - * ZSTD_Sequence, storing the sequences it finds, until it reaches a block delimiter. +/* This function scans through an array of ZSTD_Sequence, + * storing the sequences it reads, until it reaches a block delimiter. + * Note that the block delimiter includes the last literals of the block. + * @blockSize must be == sum(sequence_lengths). + * @returns @blockSize on success, and a ZSTD_error otherwise. */ static size_t -ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, - ZSTD_sequencePosition* seqPos, - const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize) +ZSTD_transferSequences_wBlockDelim(ZSTD_CCtx* cctx, + ZSTD_SequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, + ZSTD_ParamSwitch_e externalRepSearch) { U32 idx = seqPos->idx; + U32 const startIdx = idx; BYTE const* ip = (BYTE const*)(src); const BYTE* const iend = ip + blockSize; - repcodes_t updatedRepcodes; + Repcodes_t updatedRepcodes; U32 dictSize; + DEBUGLOG(5, "ZSTD_transferSequences_wBlockDelim (blockSize = %zu)", blockSize); + if (cctx->cdict) { dictSize = (U32)cctx->cdict->dictContentSize; } else if (cctx->prefixDict.dict) { @@ -5616,27 +6455,60 @@ ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, } else { dictSize = 0; } - ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); - for (; (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0) && idx < inSeqsSize; ++idx) { + ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(Repcodes_t)); + for (; idx < inSeqsSize && (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0); ++idx) { U32 const litLength = inSeqs[idx].litLength; - U32 const ll0 = (litLength == 0); U32 const matchLength = inSeqs[idx].matchLength; - U32 const offCode = ZSTD_finalizeOffCode(inSeqs[idx].offset, updatedRepcodes.rep, ll0); - ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); + U32 offBase; + + if (externalRepSearch == ZSTD_ps_disable) { + offBase = OFFSET_TO_OFFBASE(inSeqs[idx].offset); + } else { + U32 const ll0 = (litLength == 0); + offBase = ZSTD_finalizeOffBase(inSeqs[idx].offset, updatedRepcodes.rep, ll0); + ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); + } - DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); + DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); if (cctx->appliedParams.validateSequences) { seqPos->posInSrc += litLength + matchLength; - FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, - cctx->appliedParams.cParams.windowLog, dictSize), + FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, + seqPos->posInSrc, + cctx->appliedParams.cParams.windowLog, dictSize, + ZSTD_hasExtSeqProd(&cctx->appliedParams)), "Sequence validation failed"); } - RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, + RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); - ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); + ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); ip += matchLength + litLength; } - ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t)); + RETURN_ERROR_IF(idx == inSeqsSize, externalSequences_invalid, "Block delimiter not found."); + + /* If we skipped repcode search while parsing, we need to update repcodes now */ + assert(externalRepSearch != ZSTD_ps_auto); + assert(idx >= startIdx); + if (externalRepSearch == ZSTD_ps_disable && idx != startIdx) { + U32* const rep = updatedRepcodes.rep; + U32 lastSeqIdx = idx - 1; /* index of last non-block-delimiter sequence */ + + if (lastSeqIdx >= startIdx + 2) { + rep[2] = inSeqs[lastSeqIdx - 2].offset; + rep[1] = inSeqs[lastSeqIdx - 1].offset; + rep[0] = inSeqs[lastSeqIdx].offset; + } else if (lastSeqIdx == startIdx + 1) { + rep[2] = rep[0]; + rep[1] = inSeqs[lastSeqIdx - 1].offset; + rep[0] = inSeqs[lastSeqIdx].offset; + } else { + assert(lastSeqIdx == startIdx); + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = inSeqs[lastSeqIdx].offset; + } + } + + ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(Repcodes_t)); if (inSeqs[idx].litLength) { DEBUGLOG(6, "Storing last literals of size: %u", inSeqs[idx].litLength); @@ -5644,37 +6516,43 @@ ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, ip += inSeqs[idx].litLength; seqPos->posInSrc += inSeqs[idx].litLength; } - RETURN_ERROR_IF(ip != iend, corruption_detected, "Blocksize doesn't agree with block delimiter!"); + RETURN_ERROR_IF(ip != iend, externalSequences_invalid, "Blocksize doesn't agree with block delimiter!"); seqPos->idx = idx+1; - return 0; + return blockSize; } -/* Returns the number of bytes to move the current read position back by. Only non-zero - * if we ended up splitting a sequence. Otherwise, it may return a ZSTD error if something - * went wrong. +/* + * This function attempts to scan through @blockSize bytes in @src + * represented by the sequences in @inSeqs, + * storing any (partial) sequences. * - * This function will attempt to scan through blockSize bytes represented by the sequences - * in inSeqs, storing any (partial) sequences. + * Occasionally, we may want to reduce the actual number of bytes consumed from @src + * to avoid splitting a match, notably if it would produce a match smaller than MINMATCH. * - * Occasionally, we may want to change the actual number of bytes we consumed from inSeqs to - * avoid splitting a match, or to avoid splitting a match such that it would produce a match - * smaller than MINMATCH. In this case, we return the number of bytes that we didn't read from this block. + * @returns the number of bytes consumed from @src, necessarily <= @blockSize. + * Otherwise, it may return a ZSTD error if something went wrong. */ static size_t -ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, - const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize) +ZSTD_transferSequences_noDelim(ZSTD_CCtx* cctx, + ZSTD_SequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, + ZSTD_ParamSwitch_e externalRepSearch) { U32 idx = seqPos->idx; U32 startPosInSequence = seqPos->posInSequence; U32 endPosInSequence = seqPos->posInSequence + (U32)blockSize; size_t dictSize; - BYTE const* ip = (BYTE const*)(src); - BYTE const* iend = ip + blockSize; /* May be adjusted if we decide to process fewer than blockSize bytes */ - repcodes_t updatedRepcodes; + const BYTE* const istart = (const BYTE*)(src); + const BYTE* ip = istart; + const BYTE* iend = istart + blockSize; /* May be adjusted if we decide to process fewer than blockSize bytes */ + Repcodes_t updatedRepcodes; U32 bytesAdjustment = 0; U32 finalMatchSplit = 0; + /* TODO(embg) support fast parsing mode in noBlockDelim mode */ + (void)externalRepSearch; + if (cctx->cdict) { dictSize = cctx->cdict->dictContentSize; } else if (cctx->prefixDict.dict) { @@ -5682,15 +6560,15 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* } else { dictSize = 0; } - DEBUGLOG(5, "ZSTD_copySequencesToSeqStore: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); + DEBUGLOG(5, "ZSTD_transferSequences_noDelim: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); DEBUGLOG(5, "Start seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); - ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); + ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(Repcodes_t)); while (endPosInSequence && idx < inSeqsSize && !finalMatchSplit) { const ZSTD_Sequence currSeq = inSeqs[idx]; U32 litLength = currSeq.litLength; U32 matchLength = currSeq.matchLength; U32 const rawOffset = currSeq.offset; - U32 offCode; + U32 offBase; /* Modify the sequence depending on where endPosInSequence lies */ if (endPosInSequence >= currSeq.litLength + currSeq.matchLength) { @@ -5704,7 +6582,6 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* /* Move to the next sequence */ endPosInSequence -= currSeq.litLength + currSeq.matchLength; startPosInSequence = 0; - idx++; } else { /* This is the final (partial) sequence we're adding from inSeqs, and endPosInSequence does not reach the end of the match. So, we have to split the sequence */ @@ -5744,58 +6621,113 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* } /* Check if this offset can be represented with a repcode */ { U32 const ll0 = (litLength == 0); - offCode = ZSTD_finalizeOffCode(rawOffset, updatedRepcodes.rep, ll0); - ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); + offBase = ZSTD_finalizeOffBase(rawOffset, updatedRepcodes.rep, ll0); + ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); } if (cctx->appliedParams.validateSequences) { seqPos->posInSrc += litLength + matchLength; - FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, - cctx->appliedParams.cParams.windowLog, dictSize), + FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, seqPos->posInSrc, + cctx->appliedParams.cParams.windowLog, dictSize, ZSTD_hasExtSeqProd(&cctx->appliedParams)), "Sequence validation failed"); } - DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); - RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, + DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); + RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); - ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); + ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); ip += matchLength + litLength; + if (!finalMatchSplit) + idx++; /* Next Sequence */ } DEBUGLOG(5, "Ending seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); assert(idx == inSeqsSize || endPosInSequence <= inSeqs[idx].litLength + inSeqs[idx].matchLength); seqPos->idx = idx; seqPos->posInSequence = endPosInSequence; - ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t)); + ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(Repcodes_t)); iend -= bytesAdjustment; if (ip != iend) { /* Store any last literals */ - U32 lastLLSize = (U32)(iend - ip); + U32 const lastLLSize = (U32)(iend - ip); assert(ip <= iend); DEBUGLOG(6, "Storing last literals of size: %u", lastLLSize); ZSTD_storeLastLiterals(&cctx->seqStore, ip, lastLLSize); seqPos->posInSrc += lastLLSize; } - return bytesAdjustment; + return (size_t)(iend-istart); } -typedef size_t (*ZSTD_sequenceCopier) (ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, - const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize); -static ZSTD_sequenceCopier ZSTD_selectSequenceCopier(ZSTD_sequenceFormat_e mode) +/* @seqPos represents a position within @inSeqs, + * it is read and updated by this function, + * once the goal to produce a block of size @blockSize is reached. + * @return: nb of bytes consumed from @src, necessarily <= @blockSize. + */ +typedef size_t (*ZSTD_SequenceCopier_f)(ZSTD_CCtx* cctx, + ZSTD_SequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, + ZSTD_ParamSwitch_e externalRepSearch); + +static ZSTD_SequenceCopier_f ZSTD_selectSequenceCopier(ZSTD_SequenceFormat_e mode) { - ZSTD_sequenceCopier sequenceCopier = NULL; - assert(ZSTD_cParam_withinBounds(ZSTD_c_blockDelimiters, mode)); + assert(ZSTD_cParam_withinBounds(ZSTD_c_blockDelimiters, (int)mode)); if (mode == ZSTD_sf_explicitBlockDelimiters) { - return ZSTD_copySequencesToSeqStoreExplicitBlockDelim; - } else if (mode == ZSTD_sf_noBlockDelimiters) { - return ZSTD_copySequencesToSeqStoreNoBlockDelim; + return ZSTD_transferSequences_wBlockDelim; + } + assert(mode == ZSTD_sf_noBlockDelimiters); + return ZSTD_transferSequences_noDelim; +} + +/* Discover the size of next block by searching for the delimiter. + * Note that a block delimiter **must** exist in this mode, + * otherwise it's an input error. + * The block size retrieved will be later compared to ensure it remains within bounds */ +static size_t +blockSize_explicitDelimiter(const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ZSTD_SequencePosition seqPos) +{ + int end = 0; + size_t blockSize = 0; + size_t spos = seqPos.idx; + DEBUGLOG(6, "blockSize_explicitDelimiter : seq %zu / %zu", spos, inSeqsSize); + assert(spos <= inSeqsSize); + while (spos < inSeqsSize) { + end = (inSeqs[spos].offset == 0); + blockSize += inSeqs[spos].litLength + inSeqs[spos].matchLength; + if (end) { + if (inSeqs[spos].matchLength != 0) + RETURN_ERROR(externalSequences_invalid, "delimiter format error : both matchlength and offset must be == 0"); + break; + } + spos++; } - assert(sequenceCopier != NULL); - return sequenceCopier; + if (!end) + RETURN_ERROR(externalSequences_invalid, "Reached end of sequences without finding a block delimiter"); + return blockSize; } -/* Compress, block-by-block, all of the sequences given. +static size_t determine_blockSize(ZSTD_SequenceFormat_e mode, + size_t blockSize, size_t remaining, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, + ZSTD_SequencePosition seqPos) +{ + DEBUGLOG(6, "determine_blockSize : remainingSize = %zu", remaining); + if (mode == ZSTD_sf_noBlockDelimiters) { + /* Note: more a "target" block size */ + return MIN(remaining, blockSize); + } + assert(mode == ZSTD_sf_explicitBlockDelimiters); + { size_t const explicitBlockSize = blockSize_explicitDelimiter(inSeqs, inSeqsSize, seqPos); + FORWARD_IF_ERROR(explicitBlockSize, "Error while determining block size with explicit delimiters"); + if (explicitBlockSize > blockSize) + RETURN_ERROR(externalSequences_invalid, "sequences incorrectly define a too large block"); + if (explicitBlockSize > remaining) + RETURN_ERROR(externalSequences_invalid, "sequences define a frame longer than source"); + return explicitBlockSize; + } +} + +/* Compress all provided sequences, block-by-block. * * Returns the cumulative size of all compressed blocks (including their headers), * otherwise a ZSTD error. @@ -5807,15 +6739,12 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, const void* src, size_t srcSize) { size_t cSize = 0; - U32 lastBlock; - size_t blockSize; - size_t compressedSeqsSize; size_t remaining = srcSize; - ZSTD_sequencePosition seqPos = {0, 0, 0}; + ZSTD_SequencePosition seqPos = {0, 0, 0}; - BYTE const* ip = (BYTE const*)src; + const BYTE* ip = (BYTE const*)src; BYTE* op = (BYTE*)dst; - ZSTD_sequenceCopier const sequenceCopier = ZSTD_selectSequenceCopier(cctx->appliedParams.blockDelimiters); + ZSTD_SequenceCopier_f const sequenceCopier = ZSTD_selectSequenceCopier(cctx->appliedParams.blockDelimiters); DEBUGLOG(4, "ZSTD_compressSequences_internal srcSize: %zu, inSeqsSize: %zu", srcSize, inSeqsSize); /* Special case: empty frame */ @@ -5829,22 +6758,29 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, } while (remaining) { + size_t compressedSeqsSize; size_t cBlockSize; - size_t additionalByteAdjustment; - lastBlock = remaining <= cctx->blockSize; - blockSize = lastBlock ? (U32)remaining : (U32)cctx->blockSize; + size_t blockSize = determine_blockSize(cctx->appliedParams.blockDelimiters, + cctx->blockSizeMax, remaining, + inSeqs, inSeqsSize, seqPos); + U32 const lastBlock = (blockSize == remaining); + FORWARD_IF_ERROR(blockSize, "Error while trying to determine block size"); + assert(blockSize <= remaining); ZSTD_resetSeqStore(&cctx->seqStore); - DEBUGLOG(4, "Working on new block. Blocksize: %zu", blockSize); - additionalByteAdjustment = sequenceCopier(cctx, &seqPos, inSeqs, inSeqsSize, ip, blockSize); - FORWARD_IF_ERROR(additionalByteAdjustment, "Bad sequence copy"); - blockSize -= additionalByteAdjustment; + blockSize = sequenceCopier(cctx, + &seqPos, inSeqs, inSeqsSize, + ip, blockSize, + cctx->appliedParams.searchForExternalRepcodes); + FORWARD_IF_ERROR(blockSize, "Bad sequence copy"); /* If blocks are too small, emit as a nocompress block */ - if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); - DEBUGLOG(4, "Block too small, writing out nocompress block: cSize: %zu", cBlockSize); + DEBUGLOG(5, "Block too small (%zu): data remains uncompressed: cSize=%zu", blockSize, cBlockSize); cSize += cBlockSize; ip += blockSize; op += cBlockSize; @@ -5853,35 +6789,36 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, continue; } + RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize, dstSize_tooSmall, "not enough dstCapacity to write a new compressed block"); compressedSeqsSize = ZSTD_entropyCompressSeqStore(&cctx->seqStore, &cctx->blockState.prevCBlock->entropy, &cctx->blockState.nextCBlock->entropy, &cctx->appliedParams, op + ZSTD_blockHeaderSize /* Leave space for block header */, dstCapacity - ZSTD_blockHeaderSize, blockSize, - cctx->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */, + cctx->tmpWorkspace, cctx->tmpWkspSize /* statically allocated in resetCCtx */, cctx->bmi2); FORWARD_IF_ERROR(compressedSeqsSize, "Compressing sequences of block failed"); - DEBUGLOG(4, "Compressed sequences size: %zu", compressedSeqsSize); + DEBUGLOG(5, "Compressed sequences size: %zu", compressedSeqsSize); if (!cctx->isFirstBlock && ZSTD_maybeRLE(&cctx->seqStore) && - ZSTD_isRLE((BYTE const*)src, srcSize)) { - /* We don't want to emit our first block as a RLE even if it qualifies because - * doing so will cause the decoder (cli only) to throw a "should consume all input error." - * This is only an issue for zstd <= v1.4.3 - */ + ZSTD_isRLE(ip, blockSize)) { + /* Note: don't emit the first block as RLE even if it qualifies because + * doing so will cause the decoder (cli <= v1.4.3 only) to throw an (invalid) error + * "should consume all input error." + */ compressedSeqsSize = 1; } if (compressedSeqsSize == 0) { /* ZSTD_noCompressBlock writes the block header as well */ cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); - FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); - DEBUGLOG(4, "Writing out nocompress block, size: %zu", cBlockSize); + FORWARD_IF_ERROR(cBlockSize, "ZSTD_noCompressBlock failed"); + DEBUGLOG(5, "Writing out nocompress block, size: %zu", cBlockSize); } else if (compressedSeqsSize == 1) { cBlockSize = ZSTD_rleCompressBlock(op, dstCapacity, *ip, blockSize, lastBlock); - FORWARD_IF_ERROR(cBlockSize, "RLE compress block failed"); - DEBUGLOG(4, "Writing out RLE block, size: %zu", cBlockSize); + FORWARD_IF_ERROR(cBlockSize, "ZSTD_rleCompressBlock failed"); + DEBUGLOG(5, "Writing out RLE block, size: %zu", cBlockSize); } else { U32 cBlockHeader; /* Error checking and repcodes update */ @@ -5893,11 +6830,10 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, cBlockHeader = lastBlock + (((U32)bt_compressed)<<1) + (U32)(compressedSeqsSize << 3); MEM_writeLE24(op, cBlockHeader); cBlockSize = ZSTD_blockHeaderSize + compressedSeqsSize; - DEBUGLOG(4, "Writing out compressed block, size: %zu", cBlockSize); + DEBUGLOG(5, "Writing out compressed block, size: %zu", cBlockSize); } cSize += cBlockSize; - DEBUGLOG(4, "cSize running total: %zu", cSize); if (lastBlock) { break; @@ -5908,41 +6844,50 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, dstCapacity -= cBlockSize; cctx->isFirstBlock = 0; } + DEBUGLOG(5, "cSize running total: %zu (remaining dstCapacity=%zu)", cSize, dstCapacity); } + DEBUGLOG(4, "cSize final total: %zu", cSize); return cSize; } -size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapacity, +size_t ZSTD_compressSequences(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, const ZSTD_Sequence* inSeqs, size_t inSeqsSize, const void* src, size_t srcSize) { BYTE* op = (BYTE*)dst; size_t cSize = 0; - size_t compressedBlocksSize = 0; - size_t frameHeaderSize = 0; /* Transparent initialization stage, same as compressStream2() */ - DEBUGLOG(3, "ZSTD_compressSequences()"); + DEBUGLOG(4, "ZSTD_compressSequences (nbSeqs=%zu,dstCapacity=%zu)", inSeqsSize, dstCapacity); assert(cctx != NULL); FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, ZSTD_e_end, srcSize), "CCtx initialization failed"); + /* Begin writing output, starting with frame header */ - frameHeaderSize = ZSTD_writeFrameHeader(op, dstCapacity, &cctx->appliedParams, srcSize, cctx->dictID); - op += frameHeaderSize; - dstCapacity -= frameHeaderSize; - cSize += frameHeaderSize; + { size_t const frameHeaderSize = ZSTD_writeFrameHeader(op, dstCapacity, + &cctx->appliedParams, srcSize, cctx->dictID); + op += frameHeaderSize; + assert(frameHeaderSize <= dstCapacity); + dstCapacity -= frameHeaderSize; + cSize += frameHeaderSize; + } if (cctx->appliedParams.fParams.checksumFlag && srcSize) { xxh64_update(&cctx->xxhState, src, srcSize); } - /* cSize includes block header size and compressed sequences size */ - compressedBlocksSize = ZSTD_compressSequences_internal(cctx, + + /* Now generate compressed blocks */ + { size_t const cBlocksSize = ZSTD_compressSequences_internal(cctx, op, dstCapacity, inSeqs, inSeqsSize, src, srcSize); - FORWARD_IF_ERROR(compressedBlocksSize, "Compressing blocks failed!"); - cSize += compressedBlocksSize; - dstCapacity -= compressedBlocksSize; + FORWARD_IF_ERROR(cBlocksSize, "Compressing blocks failed!"); + cSize += cBlocksSize; + assert(cBlocksSize <= dstCapacity); + dstCapacity -= cBlocksSize; + } + /* Complete with frame checksum, if needed */ if (cctx->appliedParams.fParams.checksumFlag) { U32 const checksum = (U32) xxh64_digest(&cctx->xxhState); RETURN_ERROR_IF(dstCapacity<4, dstSize_tooSmall, "no room for checksum"); @@ -5951,26 +6896,557 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci cSize += 4; } - DEBUGLOG(3, "Final compressed size: %zu", cSize); + DEBUGLOG(4, "Final compressed size: %zu", cSize); + return cSize; +} + + +#if defined(__AVX2__) + +#include /* AVX2 intrinsics */ + +/* + * Convert 2 sequences per iteration, using AVX2 intrinsics: + * - offset -> offBase = offset + 2 + * - litLength -> (U16) litLength + * - matchLength -> (U16)(matchLength - 3) + * - rep is ignored + * Store only 8 bytes per SeqDef (offBase[4], litLength[2], mlBase[2]). + * + * At the end, instead of extracting two __m128i, + * we use _mm256_permute4x64_epi64(..., 0xE8) to move lane2 into lane1, + * then store the lower 16 bytes in one go. + * + * @returns 0 on succes, with no long length detected + * @returns > 0 if there is one long length (> 65535), + * indicating the position, and type. + */ +static size_t convertSequences_noRepcodes( + SeqDef* dstSeqs, + const ZSTD_Sequence* inSeqs, + size_t nbSequences) +{ + /* + * addition: + * For each 128-bit half: (offset+2, litLength+0, matchLength-3, rep+0) + */ + const __m256i addition = _mm256_setr_epi32( + ZSTD_REP_NUM, 0, -MINMATCH, 0, /* for sequence i */ + ZSTD_REP_NUM, 0, -MINMATCH, 0 /* for sequence i+1 */ + ); + + /* limit: check if there is a long length */ + const __m256i limit = _mm256_set1_epi32(65535); + + /* + * shuffle mask for byte-level rearrangement in each 128-bit half: + * + * Input layout (after addition) per 128-bit half: + * [ offset+2 (4 bytes) | litLength (4 bytes) | matchLength (4 bytes) | rep (4 bytes) ] + * We only need: + * offBase (4 bytes) = offset+2 + * litLength (2 bytes) = low 2 bytes of litLength + * mlBase (2 bytes) = low 2 bytes of (matchLength) + * => Bytes [0..3, 4..5, 8..9], zero the rest. + */ + const __m256i mask = _mm256_setr_epi8( + /* For the lower 128 bits => sequence i */ + 0, 1, 2, 3, /* offset+2 */ + 4, 5, /* litLength (16 bits) */ + 8, 9, /* matchLength (16 bits) */ + (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, + (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, + + /* For the upper 128 bits => sequence i+1 */ + 16,17,18,19, /* offset+2 */ + 20,21, /* litLength */ + 24,25, /* matchLength */ + (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, + (BYTE)0x80, (BYTE)0x80, (BYTE)0x80, (BYTE)0x80 + ); + + /* + * Next, we'll use _mm256_permute4x64_epi64(vshf, 0xE8). + * Explanation of 0xE8 = 11101000b => [lane0, lane2, lane2, lane3]. + * So the lower 128 bits become [lane0, lane2] => combining seq0 and seq1. + */ +#define PERM_LANE_0X_E8 0xE8 /* [0,2,2,3] in lane indices */ + + size_t longLen = 0, i = 0; + + /* AVX permutation depends on the specific definition of target structures */ + ZSTD_STATIC_ASSERT(sizeof(ZSTD_Sequence) == 16); + ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, offset) == 0); + ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, litLength) == 4); + ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, matchLength) == 8); + ZSTD_STATIC_ASSERT(sizeof(SeqDef) == 8); + ZSTD_STATIC_ASSERT(offsetof(SeqDef, offBase) == 0); + ZSTD_STATIC_ASSERT(offsetof(SeqDef, litLength) == 4); + ZSTD_STATIC_ASSERT(offsetof(SeqDef, mlBase) == 6); + + /* Process 2 sequences per loop iteration */ + for (; i + 1 < nbSequences; i += 2) { + /* Load 2 ZSTD_Sequence (32 bytes) */ + __m256i vin = _mm256_loadu_si256((const __m256i*)(const void*)&inSeqs[i]); + + /* Add {2, 0, -3, 0} in each 128-bit half */ + __m256i vadd = _mm256_add_epi32(vin, addition); + + /* Check for long length */ + __m256i ll_cmp = _mm256_cmpgt_epi32(vadd, limit); /* 0xFFFFFFFF for element > 65535 */ + int ll_res = _mm256_movemask_epi8(ll_cmp); + + /* Shuffle bytes so each half gives us the 8 bytes we need */ + __m256i vshf = _mm256_shuffle_epi8(vadd, mask); + /* + * Now: + * Lane0 = seq0's 8 bytes + * Lane1 = 0 + * Lane2 = seq1's 8 bytes + * Lane3 = 0 + */ + + /* Permute 64-bit lanes => move Lane2 down into Lane1. */ + __m256i vperm = _mm256_permute4x64_epi64(vshf, PERM_LANE_0X_E8); + /* + * Now the lower 16 bytes (Lane0+Lane1) = [seq0, seq1]. + * The upper 16 bytes are [Lane2, Lane3] = [seq1, 0], but we won't use them. + */ + + /* Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */ + _mm_storeu_si128((__m128i *)(void*)&dstSeqs[i], _mm256_castsi256_si128(vperm)); + /* + * This writes out 16 bytes total: + * - offset 0..7 => seq0 (offBase, litLength, mlBase) + * - offset 8..15 => seq1 (offBase, litLength, mlBase) + */ + + /* check (unlikely) long lengths > 65535 + * indices for lengths correspond to bits [4..7], [8..11], [20..23], [24..27] + * => combined mask = 0x0FF00FF0 + */ + if (UNLIKELY((ll_res & 0x0FF00FF0) != 0)) { + /* long length detected: let's figure out which one*/ + if (inSeqs[i].matchLength > 65535+MINMATCH) { + assert(longLen == 0); + longLen = i + 1; + } + if (inSeqs[i].litLength > 65535) { + assert(longLen == 0); + longLen = i + nbSequences + 1; + } + if (inSeqs[i+1].matchLength > 65535+MINMATCH) { + assert(longLen == 0); + longLen = i + 1 + 1; + } + if (inSeqs[i+1].litLength > 65535) { + assert(longLen == 0); + longLen = i + 1 + nbSequences + 1; + } + } + } + + /* Handle leftover if @nbSequences is odd */ + if (i < nbSequences) { + /* process last sequence */ + assert(i == nbSequences - 1); + dstSeqs[i].offBase = OFFSET_TO_OFFBASE(inSeqs[i].offset); + dstSeqs[i].litLength = (U16)inSeqs[i].litLength; + dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH); + /* check (unlikely) long lengths > 65535 */ + if (UNLIKELY(inSeqs[i].matchLength > 65535+MINMATCH)) { + assert(longLen == 0); + longLen = i + 1; + } + if (UNLIKELY(inSeqs[i].litLength > 65535)) { + assert(longLen == 0); + longLen = i + nbSequences + 1; + } + } + + return longLen; +} + +/* the vector implementation could also be ported to SSSE3, + * but since this implementation is targeting modern systems (>= Sapphire Rapid), + * it's not useful to develop and maintain code for older pre-AVX2 platforms */ + +#else /* no AVX2 */ + +static size_t convertSequences_noRepcodes( + SeqDef* dstSeqs, + const ZSTD_Sequence* inSeqs, + size_t nbSequences) +{ + size_t longLen = 0; + size_t n; + for (n=0; n 65535 */ + if (UNLIKELY(inSeqs[n].matchLength > 65535+MINMATCH)) { + assert(longLen == 0); + longLen = n + 1; + } + if (UNLIKELY(inSeqs[n].litLength > 65535)) { + assert(longLen == 0); + longLen = n + nbSequences + 1; + } + } + return longLen; +} + +#endif + +/* + * Precondition: Sequences must end on an explicit Block Delimiter + * @return: 0 on success, or an error code. + * Note: Sequence validation functionality has been disabled (removed). + * This is helpful to generate a lean main pipeline, improving performance. + * It may be re-inserted later. + */ +size_t ZSTD_convertBlockSequences(ZSTD_CCtx* cctx, + const ZSTD_Sequence* const inSeqs, size_t nbSequences, + int repcodeResolution) +{ + Repcodes_t updatedRepcodes; + size_t seqNb = 0; + + DEBUGLOG(5, "ZSTD_convertBlockSequences (nbSequences = %zu)", nbSequences); + + RETURN_ERROR_IF(nbSequences >= cctx->seqStore.maxNbSeq, externalSequences_invalid, + "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); + + ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(Repcodes_t)); + + /* check end condition */ + assert(nbSequences >= 1); + assert(inSeqs[nbSequences-1].matchLength == 0); + assert(inSeqs[nbSequences-1].offset == 0); + + /* Convert Sequences from public format to internal format */ + if (!repcodeResolution) { + size_t const longl = convertSequences_noRepcodes(cctx->seqStore.sequencesStart, inSeqs, nbSequences-1); + cctx->seqStore.sequences = cctx->seqStore.sequencesStart + nbSequences-1; + if (longl) { + DEBUGLOG(5, "long length"); + assert(cctx->seqStore.longLengthType == ZSTD_llt_none); + if (longl <= nbSequences-1) { + DEBUGLOG(5, "long match length detected at pos %zu", longl-1); + cctx->seqStore.longLengthType = ZSTD_llt_matchLength; + cctx->seqStore.longLengthPos = (U32)(longl-1); + } else { + DEBUGLOG(5, "long literals length detected at pos %zu", longl-nbSequences); + assert(longl <= 2* (nbSequences-1)); + cctx->seqStore.longLengthType = ZSTD_llt_literalLength; + cctx->seqStore.longLengthPos = (U32)(longl-(nbSequences-1)-1); + } + } + } else { + for (seqNb = 0; seqNb < nbSequences - 1 ; seqNb++) { + U32 const litLength = inSeqs[seqNb].litLength; + U32 const matchLength = inSeqs[seqNb].matchLength; + U32 const ll0 = (litLength == 0); + U32 const offBase = ZSTD_finalizeOffBase(inSeqs[seqNb].offset, updatedRepcodes.rep, ll0); + + DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); + ZSTD_storeSeqOnly(&cctx->seqStore, litLength, offBase, matchLength); + ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); + } + } + + /* If we skipped repcode search while parsing, we need to update repcodes now */ + if (!repcodeResolution && nbSequences > 1) { + U32* const rep = updatedRepcodes.rep; + + if (nbSequences >= 4) { + U32 lastSeqIdx = (U32)nbSequences - 2; /* index of last full sequence */ + rep[2] = inSeqs[lastSeqIdx - 2].offset; + rep[1] = inSeqs[lastSeqIdx - 1].offset; + rep[0] = inSeqs[lastSeqIdx].offset; + } else if (nbSequences == 3) { + rep[2] = rep[0]; + rep[1] = inSeqs[0].offset; + rep[0] = inSeqs[1].offset; + } else { + assert(nbSequences == 2); + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = inSeqs[0].offset; + } + } + + ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(Repcodes_t)); + + return 0; +} + +#if defined(ZSTD_ARCH_X86_AVX2) + +BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs) +{ + size_t i; + __m256i const zeroVec = _mm256_setzero_si256(); + __m256i sumVec = zeroVec; /* accumulates match+lit in 32-bit lanes */ + ZSTD_ALIGNED(32) U32 tmp[8]; /* temporary buffer for reduction */ + size_t mSum = 0, lSum = 0; + ZSTD_STATIC_ASSERT(sizeof(ZSTD_Sequence) == 16); + + /* Process 2 structs (32 bytes) at a time */ + for (i = 0; i + 2 <= nbSeqs; i += 2) { + /* Load two consecutive ZSTD_Sequence (8×4 = 32 bytes) */ + __m256i data = _mm256_loadu_si256((const __m256i*)(const void*)&seqs[i]); + /* check end of block signal */ + __m256i cmp = _mm256_cmpeq_epi32(data, zeroVec); + int cmp_res = _mm256_movemask_epi8(cmp); + /* indices for match lengths correspond to bits [8..11], [24..27] + * => combined mask = 0x0F000F00 */ + ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, matchLength) == 8); + if (cmp_res & 0x0F000F00) break; + /* Accumulate in sumVec */ + sumVec = _mm256_add_epi32(sumVec, data); + } + + /* Horizontal reduction */ + _mm256_store_si256((__m256i*)tmp, sumVec); + lSum = tmp[1] + tmp[5]; + mSum = tmp[2] + tmp[6]; + + /* Handle the leftover */ + for (; i < nbSeqs; i++) { + lSum += seqs[i].litLength; + mSum += seqs[i].matchLength; + if (seqs[i].matchLength == 0) break; /* end of block */ + } + + if (i==nbSeqs) { + /* reaching end of sequences: end of block signal was not present */ + BlockSummary bs; + bs.nbSequences = ERROR(externalSequences_invalid); + return bs; + } + { BlockSummary bs; + bs.nbSequences = i+1; + bs.blockSize = lSum + mSum; + bs.litSize = lSum; + return bs; + } +} + +#else + +BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs) +{ + size_t totalMatchSize = 0; + size_t litSize = 0; + size_t n; + assert(seqs); + for (n=0; nappliedParams.searchForExternalRepcodes == ZSTD_ps_enable); + assert(cctx->appliedParams.searchForExternalRepcodes != ZSTD_ps_auto); + + DEBUGLOG(4, "ZSTD_compressSequencesAndLiterals_internal: nbSeqs=%zu, litSize=%zu", nbSequences, litSize); + RETURN_ERROR_IF(nbSequences == 0, externalSequences_invalid, "Requires at least 1 end-of-block"); + + /* Special case: empty frame */ + if ((nbSequences == 1) && (inSeqs[0].litLength == 0)) { + U32 const cBlockHeader24 = 1 /* last block */ + (((U32)bt_raw)<<1); + RETURN_ERROR_IF(dstCapacity<3, dstSize_tooSmall, "No room for empty frame block header"); + MEM_writeLE24(op, cBlockHeader24); + op += ZSTD_blockHeaderSize; + dstCapacity -= ZSTD_blockHeaderSize; + cSize += ZSTD_blockHeaderSize; + } + + while (nbSequences) { + size_t compressedSeqsSize, cBlockSize, conversionStatus; + BlockSummary const block = ZSTD_get1BlockSummary(inSeqs, nbSequences); + U32 const lastBlock = (block.nbSequences == nbSequences); + FORWARD_IF_ERROR(block.nbSequences, "Error while trying to determine nb of sequences for a block"); + assert(block.nbSequences <= nbSequences); + RETURN_ERROR_IF(block.litSize > litSize, externalSequences_invalid, "discrepancy: Sequences require more literals than present in buffer"); + ZSTD_resetSeqStore(&cctx->seqStore); + + conversionStatus = ZSTD_convertBlockSequences(cctx, + inSeqs, block.nbSequences, + repcodeResolution); + FORWARD_IF_ERROR(conversionStatus, "Bad sequence conversion"); + inSeqs += block.nbSequences; + nbSequences -= block.nbSequences; + remaining -= block.blockSize; + + /* Note: when blockSize is very small, other variant send it uncompressed. + * Here, we still send the sequences, because we don't have the original source to send it uncompressed. + * One could imagine in theory reproducing the source from the sequences, + * but that's complex and costly memory intensive, and goes against the objectives of this variant. */ + + RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize, dstSize_tooSmall, "not enough dstCapacity to write a new compressed block"); + + compressedSeqsSize = ZSTD_entropyCompressSeqStore_internal( + op + ZSTD_blockHeaderSize /* Leave space for block header */, dstCapacity - ZSTD_blockHeaderSize, + literals, block.litSize, + &cctx->seqStore, + &cctx->blockState.prevCBlock->entropy, &cctx->blockState.nextCBlock->entropy, + &cctx->appliedParams, + cctx->tmpWorkspace, cctx->tmpWkspSize /* statically allocated in resetCCtx */, + cctx->bmi2); + FORWARD_IF_ERROR(compressedSeqsSize, "Compressing sequences of block failed"); + /* note: the spec forbids for any compressed block to be larger than maximum block size */ + if (compressedSeqsSize > cctx->blockSizeMax) compressedSeqsSize = 0; + DEBUGLOG(5, "Compressed sequences size: %zu", compressedSeqsSize); + litSize -= block.litSize; + literals = (const char*)literals + block.litSize; + + /* Note: difficult to check source for RLE block when only Literals are provided, + * but it could be considered from analyzing the sequence directly */ + + if (compressedSeqsSize == 0) { + /* Sending uncompressed blocks is out of reach, because the source is not provided. + * In theory, one could use the sequences to regenerate the source, like a decompressor, + * but it's complex, and memory hungry, killing the purpose of this variant. + * Current outcome: generate an error code. + */ + RETURN_ERROR(cannotProduce_uncompressedBlock, "ZSTD_compressSequencesAndLiterals cannot generate an uncompressed block"); + } else { + U32 cBlockHeader; + assert(compressedSeqsSize > 1); /* no RLE */ + /* Error checking and repcodes update */ + ZSTD_blockState_confirmRepcodesAndEntropyTables(&cctx->blockState); + if (cctx->blockState.prevCBlock->entropy.fse.offcode_repeatMode == FSE_repeat_valid) + cctx->blockState.prevCBlock->entropy.fse.offcode_repeatMode = FSE_repeat_check; + + /* Write block header into beginning of block*/ + cBlockHeader = lastBlock + (((U32)bt_compressed)<<1) + (U32)(compressedSeqsSize << 3); + MEM_writeLE24(op, cBlockHeader); + cBlockSize = ZSTD_blockHeaderSize + compressedSeqsSize; + DEBUGLOG(5, "Writing out compressed block, size: %zu", cBlockSize); + } + + cSize += cBlockSize; + op += cBlockSize; + dstCapacity -= cBlockSize; + cctx->isFirstBlock = 0; + DEBUGLOG(5, "cSize running total: %zu (remaining dstCapacity=%zu)", cSize, dstCapacity); + + if (lastBlock) { + assert(nbSequences == 0); + break; + } + } + + RETURN_ERROR_IF(litSize != 0, externalSequences_invalid, "literals must be entirely and exactly consumed"); + RETURN_ERROR_IF(remaining != 0, externalSequences_invalid, "Sequences must represent a total of exactly srcSize=%zu", srcSize); + DEBUGLOG(4, "cSize final total: %zu", cSize); + return cSize; +} + +size_t +ZSTD_compressSequencesAndLiterals(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, + const void* literals, size_t litSize, size_t litCapacity, + size_t decompressedSize) +{ + BYTE* op = (BYTE*)dst; + size_t cSize = 0; + + /* Transparent initialization stage, same as compressStream2() */ + DEBUGLOG(4, "ZSTD_compressSequencesAndLiterals (dstCapacity=%zu)", dstCapacity); + assert(cctx != NULL); + if (litCapacity < litSize) { + RETURN_ERROR(workSpace_tooSmall, "literals buffer is not large enough: must be at least 8 bytes larger than litSize (risk of read out-of-bound)"); + } + FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, ZSTD_e_end, decompressedSize), "CCtx initialization failed"); + + if (cctx->appliedParams.blockDelimiters == ZSTD_sf_noBlockDelimiters) { + RETURN_ERROR(frameParameter_unsupported, "This mode is only compatible with explicit delimiters"); + } + if (cctx->appliedParams.validateSequences) { + RETURN_ERROR(parameter_unsupported, "This mode is not compatible with Sequence validation"); + } + if (cctx->appliedParams.fParams.checksumFlag) { + RETURN_ERROR(frameParameter_unsupported, "this mode is not compatible with frame checksum"); + } + + /* Begin writing output, starting with frame header */ + { size_t const frameHeaderSize = ZSTD_writeFrameHeader(op, dstCapacity, + &cctx->appliedParams, decompressedSize, cctx->dictID); + op += frameHeaderSize; + assert(frameHeaderSize <= dstCapacity); + dstCapacity -= frameHeaderSize; + cSize += frameHeaderSize; + } + + /* Now generate compressed blocks */ + { size_t const cBlocksSize = ZSTD_compressSequencesAndLiterals_internal(cctx, + op, dstCapacity, + inSeqs, inSeqsSize, + literals, litSize, decompressedSize); + FORWARD_IF_ERROR(cBlocksSize, "Compressing blocks failed!"); + cSize += cBlocksSize; + assert(cBlocksSize <= dstCapacity); + dstCapacity -= cBlocksSize; + } + + DEBUGLOG(4, "Final compressed size: %zu", cSize); return cSize; } /*====== Finalize ======*/ +static ZSTD_inBuffer inBuffer_forEndFlush(const ZSTD_CStream* zcs) +{ + const ZSTD_inBuffer nullInput = { NULL, 0, 0 }; + const int stableInput = (zcs->appliedParams.inBufferMode == ZSTD_bm_stable); + return stableInput ? zcs->expectedInBuffer : nullInput; +} + /*! ZSTD_flushStream() : * @return : amount of data remaining to flush */ size_t ZSTD_flushStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) { - ZSTD_inBuffer input = { NULL, 0, 0 }; + ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); + input.size = input.pos; /* do not ingest more input during flush */ return ZSTD_compressStream2(zcs, output, &input, ZSTD_e_flush); } - size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) { - ZSTD_inBuffer input = { NULL, 0, 0 }; + ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); size_t const remainingToFlush = ZSTD_compressStream2(zcs, output, &input, ZSTD_e_end); - FORWARD_IF_ERROR( remainingToFlush , "ZSTD_compressStream2 failed"); + FORWARD_IF_ERROR(remainingToFlush , "ZSTD_compressStream2(,,ZSTD_e_end) failed"); if (zcs->appliedParams.nbWorkers > 0) return remainingToFlush; /* minimal estimation */ /* single thread mode : attempt to calculate remaining to flush more precisely */ { size_t const lastBlockSize = zcs->frameEnded ? 0 : ZSTD_BLOCKHEADERSIZE; @@ -6046,7 +7522,7 @@ static void ZSTD_dedicatedDictSearch_revertCParams( } } -static U64 ZSTD_getCParamRowSize(U64 srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode) +static U64 ZSTD_getCParamRowSize(U64 srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode) { switch (mode) { case ZSTD_cpm_unknown: @@ -6070,8 +7546,8 @@ static U64 ZSTD_getCParamRowSize(U64 srcSizeHint, size_t dictSize, ZSTD_cParamMo * @return ZSTD_compressionParameters structure for a selected compression level, srcSize and dictSize. * Note: srcSizeHint 0 means 0, use ZSTD_CONTENTSIZE_UNKNOWN for unknown. * Use dictSize == 0 for unknown or unused. - * Note: `mode` controls how we treat the `dictSize`. See docs for `ZSTD_cParamMode_e`. */ -static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode) + * Note: `mode` controls how we treat the `dictSize`. See docs for `ZSTD_CParamMode_e`. */ +static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode) { U64 const rSize = ZSTD_getCParamRowSize(srcSizeHint, dictSize, mode); U32 const tableID = (rSize <= 256 KB) + (rSize <= 128 KB) + (rSize <= 16 KB); @@ -6092,7 +7568,7 @@ static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, cp.targetLength = (unsigned)(-clampedCompressionLevel); } /* refine parameters based on srcSize & dictSize */ - return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode); + return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode, ZSTD_ps_auto); } } @@ -6109,7 +7585,9 @@ ZSTD_compressionParameters ZSTD_getCParams(int compressionLevel, unsigned long l * same idea as ZSTD_getCParams() * @return a `ZSTD_parameters` structure (instead of `ZSTD_compressionParameters`). * Fields of `ZSTD_frameParameters` are set to default values */ -static ZSTD_parameters ZSTD_getParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode) { +static ZSTD_parameters +ZSTD_getParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode) +{ ZSTD_parameters params; ZSTD_compressionParameters const cParams = ZSTD_getCParams_internal(compressionLevel, srcSizeHint, dictSize, mode); DEBUGLOG(5, "ZSTD_getParams (cLevel=%i)", compressionLevel); @@ -6123,7 +7601,34 @@ static ZSTD_parameters ZSTD_getParams_internal(int compressionLevel, unsigned lo * same idea as ZSTD_getCParams() * @return a `ZSTD_parameters` structure (instead of `ZSTD_compressionParameters`). * Fields of `ZSTD_frameParameters` are set to default values */ -ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize) { +ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize) +{ if (srcSizeHint == 0) srcSizeHint = ZSTD_CONTENTSIZE_UNKNOWN; return ZSTD_getParams_internal(compressionLevel, srcSizeHint, dictSize, ZSTD_cpm_unknown); } + +void ZSTD_registerSequenceProducer( + ZSTD_CCtx* zc, + void* extSeqProdState, + ZSTD_sequenceProducer_F extSeqProdFunc) +{ + assert(zc != NULL); + ZSTD_CCtxParams_registerSequenceProducer( + &zc->requestedParams, extSeqProdState, extSeqProdFunc + ); +} + +void ZSTD_CCtxParams_registerSequenceProducer( + ZSTD_CCtx_params* params, + void* extSeqProdState, + ZSTD_sequenceProducer_F extSeqProdFunc) +{ + assert(params != NULL); + if (extSeqProdFunc != NULL) { + params->extSeqProdFunc = extSeqProdFunc; + params->extSeqProdState = extSeqProdState; + } else { + params->extSeqProdFunc = NULL; + params->extSeqProdState = NULL; + } +} diff --git a/lib/zstd/compress/zstd_compress_internal.h b/lib/zstd/compress/zstd_compress_internal.h index 71697a11ae30..b10978385876 100644 --- a/lib/zstd/compress/zstd_compress_internal.h +++ b/lib/zstd/compress/zstd_compress_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -20,7 +21,8 @@ ***************************************/ #include "../common/zstd_internal.h" #include "zstd_cwksp.h" - +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_NbCommonBytes */ +#include "zstd_preSplit.h" /* ZSTD_SLIPBLOCK_WORKSPACESIZE */ /*-************************************* * Constants @@ -32,7 +34,7 @@ It's not a big deal though : candidate will just be sorted again. Additionally, candidate position 1 will be lost. But candidate 1 cannot hide a large tree of candidates, so it's a minimal loss. - The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table re-use with a different strategy. + The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table reuse with a different strategy. This constant is required by ZSTD_compressBlock_btlazy2() and ZSTD_reduceTable_internal() */ @@ -75,6 +77,70 @@ typedef struct { ZSTD_fseCTables_t fse; } ZSTD_entropyCTables_t; +/* ********************************************* +* Sequences * +***********************************************/ +typedef struct SeqDef_s { + U32 offBase; /* offBase == Offset + ZSTD_REP_NUM, or repcode 1,2,3 */ + U16 litLength; + U16 mlBase; /* mlBase == matchLength - MINMATCH */ +} SeqDef; + +/* Controls whether seqStore has a single "long" litLength or matchLength. See SeqStore_t. */ +typedef enum { + ZSTD_llt_none = 0, /* no longLengthType */ + ZSTD_llt_literalLength = 1, /* represents a long literal */ + ZSTD_llt_matchLength = 2 /* represents a long match */ +} ZSTD_longLengthType_e; + +typedef struct { + SeqDef* sequencesStart; + SeqDef* sequences; /* ptr to end of sequences */ + BYTE* litStart; + BYTE* lit; /* ptr to end of literals */ + BYTE* llCode; + BYTE* mlCode; + BYTE* ofCode; + size_t maxNbSeq; + size_t maxNbLit; + + /* longLengthPos and longLengthType to allow us to represent either a single litLength or matchLength + * in the seqStore that has a value larger than U16 (if it exists). To do so, we increment + * the existing value of the litLength or matchLength by 0x10000. + */ + ZSTD_longLengthType_e longLengthType; + U32 longLengthPos; /* Index of the sequence to apply long length modification to */ +} SeqStore_t; + +typedef struct { + U32 litLength; + U32 matchLength; +} ZSTD_SequenceLength; + +/* + * Returns the ZSTD_SequenceLength for the given sequences. It handles the decoding of long sequences + * indicated by longLengthPos and longLengthType, and adds MINMATCH back to matchLength. + */ +MEM_STATIC ZSTD_SequenceLength ZSTD_getSequenceLength(SeqStore_t const* seqStore, SeqDef const* seq) +{ + ZSTD_SequenceLength seqLen; + seqLen.litLength = seq->litLength; + seqLen.matchLength = seq->mlBase + MINMATCH; + if (seqStore->longLengthPos == (U32)(seq - seqStore->sequencesStart)) { + if (seqStore->longLengthType == ZSTD_llt_literalLength) { + seqLen.litLength += 0x10000; + } + if (seqStore->longLengthType == ZSTD_llt_matchLength) { + seqLen.matchLength += 0x10000; + } + } + return seqLen; +} + +const SeqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx); /* compress & dictBuilder */ +int ZSTD_seqToCodes(const SeqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ + + /* ********************************************* * Entropy buffer statistics structs and funcs * ***********************************************/ @@ -84,7 +150,7 @@ typedef struct { * hufDesSize refers to the size of huffman tree description in bytes. * This metadata is populated in ZSTD_buildBlockEntropyStats_literals() */ typedef struct { - symbolEncodingType_e hType; + SymbolEncodingType_e hType; BYTE hufDesBuffer[ZSTD_MAX_HUF_HEADER_SIZE]; size_t hufDesSize; } ZSTD_hufCTablesMetadata_t; @@ -95,9 +161,9 @@ typedef struct { * fseTablesSize refers to the size of fse tables in bytes. * This metadata is populated in ZSTD_buildBlockEntropyStats_sequences() */ typedef struct { - symbolEncodingType_e llType; - symbolEncodingType_e ofType; - symbolEncodingType_e mlType; + SymbolEncodingType_e llType; + SymbolEncodingType_e ofType; + SymbolEncodingType_e mlType; BYTE fseTablesBuffer[ZSTD_MAX_FSE_HEADERS_SIZE]; size_t fseTablesSize; size_t lastCountSize; /* This is to account for bug in 1.3.4. More detail in ZSTD_entropyCompressSeqStore_internal() */ @@ -111,12 +177,13 @@ typedef struct { /* ZSTD_buildBlockEntropyStats() : * Builds entropy for the block. * @return : 0 on success or error code */ -size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize); +size_t ZSTD_buildBlockEntropyStats( + const SeqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize); /* ******************************* * Compression internals structs * @@ -140,28 +207,29 @@ typedef struct { stopped. posInSequence <= seq[pos].litLength + seq[pos].matchLength */ size_t size; /* The number of sequences. <= capacity. */ size_t capacity; /* The capacity starting from `seq` pointer */ -} rawSeqStore_t; +} RawSeqStore_t; -UNUSED_ATTR static const rawSeqStore_t kNullRawSeqStore = {NULL, 0, 0, 0, 0}; +UNUSED_ATTR static const RawSeqStore_t kNullRawSeqStore = {NULL, 0, 0, 0, 0}; typedef struct { - int price; - U32 off; - U32 mlen; - U32 litlen; - U32 rep[ZSTD_REP_NUM]; + int price; /* price from beginning of segment to this position */ + U32 off; /* offset of previous match */ + U32 mlen; /* length of previous match */ + U32 litlen; /* nb of literals since previous match */ + U32 rep[ZSTD_REP_NUM]; /* offset history after previous match */ } ZSTD_optimal_t; typedef enum { zop_dynamic=0, zop_predef } ZSTD_OptPrice_e; +#define ZSTD_OPT_SIZE (ZSTD_OPT_NUM+3) typedef struct { /* All tables are allocated inside cctx->workspace by ZSTD_resetCCtx_internal() */ unsigned* litFreq; /* table of literals statistics, of size 256 */ unsigned* litLengthFreq; /* table of litLength statistics, of size (MaxLL+1) */ unsigned* matchLengthFreq; /* table of matchLength statistics, of size (MaxML+1) */ unsigned* offCodeFreq; /* table of offCode statistics, of size (MaxOff+1) */ - ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_NUM+1 */ - ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_NUM+1 */ + ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_SIZE */ + ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_SIZE */ U32 litSum; /* nb of literals */ U32 litLengthSum; /* nb of litLength codes */ @@ -173,7 +241,7 @@ typedef struct { U32 offCodeSumBasePrice; /* to compare to log2(offreq) */ ZSTD_OptPrice_e priceType; /* prices can be determined dynamically, or follow a pre-defined cost structure */ const ZSTD_entropyCTables_t* symbolCosts; /* pre-calculated dictionary statistics */ - ZSTD_paramSwitch_e literalCompressionMode; + ZSTD_ParamSwitch_e literalCompressionMode; } optState_t; typedef struct { @@ -195,11 +263,11 @@ typedef struct { #define ZSTD_WINDOW_START_INDEX 2 -typedef struct ZSTD_matchState_t ZSTD_matchState_t; +typedef struct ZSTD_MatchState_t ZSTD_MatchState_t; #define ZSTD_ROW_HASH_CACHE_SIZE 8 /* Size of prefetching hash cache for row-based matchfinder */ -struct ZSTD_matchState_t { +struct ZSTD_MatchState_t { ZSTD_window_t window; /* State for window round buffer management */ U32 loadedDictEnd; /* index of end of dictionary, within context's referential. * When loadedDictEnd != 0, a dictionary is in use, and still valid. @@ -212,28 +280,42 @@ struct ZSTD_matchState_t { U32 hashLog3; /* dispatch table for matches of len==3 : larger == faster, more memory */ U32 rowHashLog; /* For row-based matchfinder: Hashlog based on nb of rows in the hashTable.*/ - U16* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. */ + BYTE* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. */ U32 hashCache[ZSTD_ROW_HASH_CACHE_SIZE]; /* For row-based matchFinder: a cache of hashes to improve speed */ + U64 hashSalt; /* For row-based matchFinder: salts the hash for reuse of tag table */ + U32 hashSaltEntropy; /* For row-based matchFinder: collects entropy for salt generation */ U32* hashTable; U32* hashTable3; U32* chainTable; - U32 forceNonContiguous; /* Non-zero if we should force non-contiguous load for the next window update. */ + int forceNonContiguous; /* Non-zero if we should force non-contiguous load for the next window update. */ int dedicatedDictSearch; /* Indicates whether this matchState is using the * dedicated dictionary search structure. */ optState_t opt; /* optimal parser state */ - const ZSTD_matchState_t* dictMatchState; + const ZSTD_MatchState_t* dictMatchState; ZSTD_compressionParameters cParams; - const rawSeqStore_t* ldmSeqStore; + const RawSeqStore_t* ldmSeqStore; + + /* Controls prefetching in some dictMatchState matchfinders. + * This behavior is controlled from the cctx ms. + * This parameter has no effect in the cdict ms. */ + int prefetchCDictTables; + + /* When == 0, lazy match finders insert every position. + * When != 0, lazy match finders only insert positions they search. + * This allows them to skip much faster over incompressible data, + * at a small cost to compression ratio. + */ + int lazySkipping; }; typedef struct { ZSTD_compressedBlockState_t* prevCBlock; ZSTD_compressedBlockState_t* nextCBlock; - ZSTD_matchState_t matchState; + ZSTD_MatchState_t matchState; } ZSTD_blockState_t; typedef struct { @@ -260,7 +342,7 @@ typedef struct { } ldmState_t; typedef struct { - ZSTD_paramSwitch_e enableLdm; /* ZSTD_ps_enable to enable LDM. ZSTD_ps_auto by default */ + ZSTD_ParamSwitch_e enableLdm; /* ZSTD_ps_enable to enable LDM. ZSTD_ps_auto by default */ U32 hashLog; /* Log size of hashTable */ U32 bucketSizeLog; /* Log bucket size for collision resolution, at most 8 */ U32 minMatchLength; /* Minimum match length */ @@ -291,7 +373,7 @@ struct ZSTD_CCtx_params_s { * There is no guarantee that hint is close to actual source size */ ZSTD_dictAttachPref_e attachDictPref; - ZSTD_paramSwitch_e literalCompressionMode; + ZSTD_ParamSwitch_e literalCompressionMode; /* Multithreading: used to pass parameters to mtctx */ int nbWorkers; @@ -310,24 +392,54 @@ struct ZSTD_CCtx_params_s { ZSTD_bufferMode_e outBufferMode; /* Sequence compression API */ - ZSTD_sequenceFormat_e blockDelimiters; + ZSTD_SequenceFormat_e blockDelimiters; int validateSequences; - /* Block splitting */ - ZSTD_paramSwitch_e useBlockSplitter; + /* Block splitting + * @postBlockSplitter executes split analysis after sequences are produced, + * it's more accurate but consumes more resources. + * @preBlockSplitter_level splits before knowing sequences, + * it's more approximative but also cheaper. + * Valid @preBlockSplitter_level values range from 0 to 6 (included). + * 0 means auto, 1 means do not split, + * then levels are sorted in increasing cpu budget, from 2 (fastest) to 6 (slowest). + * Highest @preBlockSplitter_level combines well with @postBlockSplitter. + */ + ZSTD_ParamSwitch_e postBlockSplitter; + int preBlockSplitter_level; + + /* Adjust the max block size*/ + size_t maxBlockSize; /* Param for deciding whether to use row-based matchfinder */ - ZSTD_paramSwitch_e useRowMatchFinder; + ZSTD_ParamSwitch_e useRowMatchFinder; /* Always load a dictionary in ext-dict mode (not prefix mode)? */ int deterministicRefPrefix; /* Internal use, for createCCtxParams() and freeCCtxParams() only */ ZSTD_customMem customMem; + + /* Controls prefetching in some dictMatchState matchfinders */ + ZSTD_ParamSwitch_e prefetchCDictTables; + + /* Controls whether zstd will fall back to an internal matchfinder + * if the external matchfinder returns an error code. */ + int enableMatchFinderFallback; + + /* Parameters for the external sequence producer API. + * Users set these parameters through ZSTD_registerSequenceProducer(). + * It is not possible to set these parameters individually through the public API. */ + void* extSeqProdState; + ZSTD_sequenceProducer_F extSeqProdFunc; + + /* Controls repcode search in external sequence parsing */ + ZSTD_ParamSwitch_e searchForExternalRepcodes; }; /* typedef'd to ZSTD_CCtx_params within "zstd.h" */ #define COMPRESS_SEQUENCES_WORKSPACE_SIZE (sizeof(unsigned) * (MaxSeq + 2)) #define ENTROPY_WORKSPACE_SIZE (HUF_WORKSPACE_SIZE + COMPRESS_SEQUENCES_WORKSPACE_SIZE) +#define TMP_WORKSPACE_SIZE (MAX(ENTROPY_WORKSPACE_SIZE, ZSTD_SLIPBLOCK_WORKSPACESIZE)) /* * Indicates whether this compression proceeds directly from user-provided @@ -345,11 +457,11 @@ typedef enum { */ #define ZSTD_MAX_NB_BLOCK_SPLITS 196 typedef struct { - seqStore_t fullSeqStoreChunk; - seqStore_t firstHalfSeqStore; - seqStore_t secondHalfSeqStore; - seqStore_t currSeqStore; - seqStore_t nextSeqStore; + SeqStore_t fullSeqStoreChunk; + SeqStore_t firstHalfSeqStore; + SeqStore_t secondHalfSeqStore; + SeqStore_t currSeqStore; + SeqStore_t nextSeqStore; U32 partitions[ZSTD_MAX_NB_BLOCK_SPLITS]; ZSTD_entropyCTablesMetadata_t entropyMetadata; @@ -366,7 +478,7 @@ struct ZSTD_CCtx_s { size_t dictContentSize; ZSTD_cwksp workspace; /* manages buffer for dynamic allocations */ - size_t blockSize; + size_t blockSizeMax; unsigned long long pledgedSrcSizePlusOne; /* this way, 0 (default) == unknown */ unsigned long long consumedSrcSize; unsigned long long producedCSize; @@ -378,13 +490,14 @@ struct ZSTD_CCtx_s { int isFirstBlock; int initialized; - seqStore_t seqStore; /* sequences storage ptrs */ + SeqStore_t seqStore; /* sequences storage ptrs */ ldmState_t ldmState; /* long distance matching state */ rawSeq* ldmSequences; /* Storage for the ldm output sequences */ size_t maxNbLdmSequences; - rawSeqStore_t externSeqStore; /* Mutable reference to external sequences */ + RawSeqStore_t externSeqStore; /* Mutable reference to external sequences */ ZSTD_blockState_t blockState; - U32* entropyWorkspace; /* entropy workspace of ENTROPY_WORKSPACE_SIZE bytes */ + void* tmpWorkspace; /* used as substitute of stack space - must be aligned for S64 type */ + size_t tmpWkspSize; /* Whether we are streaming or not */ ZSTD_buffered_policy_e bufferedPolicy; @@ -404,6 +517,7 @@ struct ZSTD_CCtx_s { /* Stable in/out buffer verification */ ZSTD_inBuffer expectedInBuffer; + size_t stableIn_notConsumed; /* nb bytes within stable input buffer that are said to be consumed but are not */ size_t expectedOutBufferSize; /* Dictionary */ @@ -417,9 +531,14 @@ struct ZSTD_CCtx_s { /* Workspace for block splitter */ ZSTD_blockSplitCtx blockSplitCtx; + + /* Buffer for output from external sequence producer */ + ZSTD_Sequence* extSeqBuf; + size_t extSeqBufCapacity; }; typedef enum { ZSTD_dtlm_fast, ZSTD_dtlm_full } ZSTD_dictTableLoadMethod_e; +typedef enum { ZSTD_tfp_forCCtx, ZSTD_tfp_forCDict } ZSTD_tableFillPurpose_e; typedef enum { ZSTD_noDict = 0, @@ -441,17 +560,17 @@ typedef enum { * In this mode we take both the source size and the dictionary size * into account when selecting and adjusting the parameters. */ - ZSTD_cpm_unknown = 3, /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. + ZSTD_cpm_unknown = 3 /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. * We don't know what these parameters are for. We default to the legacy * behavior of taking both the source size and the dict size into account * when selecting and adjusting parameters. */ -} ZSTD_cParamMode_e; +} ZSTD_CParamMode_e; -typedef size_t (*ZSTD_blockCompressor) ( - ZSTD_matchState_t* bs, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +typedef size_t (*ZSTD_BlockCompressor_f) ( + ZSTD_MatchState_t* bs, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramSwitch_e rowMatchfinderMode, ZSTD_dictMode_e dictMode); +ZSTD_BlockCompressor_f ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_ParamSwitch_e rowMatchfinderMode, ZSTD_dictMode_e dictMode); MEM_STATIC U32 ZSTD_LLcode(U32 litLength) @@ -497,12 +616,33 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value) return 1; } +/* ZSTD_selectAddr: + * @return index >= lowLimit ? candidate : backup, + * tries to force branchless codegen. */ +MEM_STATIC const BYTE* +ZSTD_selectAddr(U32 index, U32 lowLimit, const BYTE* candidate, const BYTE* backup) +{ +#if defined(__x86_64__) + __asm__ ( + "cmp %1, %2\n" + "cmova %3, %0\n" + : "+r"(candidate) + : "r"(index), "r"(lowLimit), "r"(backup) + ); + return candidate; +#else + return index >= lowLimit ? candidate : backup; +#endif +} + /* ZSTD_noCompressBlock() : * Writes uncompressed block to dst buffer from given src. * Returns the size of the block */ -MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) +MEM_STATIC size_t +ZSTD_noCompressBlock(void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) { U32 const cBlockHeader24 = lastBlock + (((U32)bt_raw)<<1) + (U32)(srcSize << 3); + DEBUGLOG(5, "ZSTD_noCompressBlock (srcSize=%zu, dstCapacity=%zu)", srcSize, dstCapacity); RETURN_ERROR_IF(srcSize + ZSTD_blockHeaderSize > dstCapacity, dstSize_tooSmall, "dst buf too small for uncompressed block"); MEM_writeLE24(dst, cBlockHeader24); @@ -510,7 +650,8 @@ MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const voi return ZSTD_blockHeaderSize + srcSize; } -MEM_STATIC size_t ZSTD_rleCompressBlock (void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) +MEM_STATIC size_t +ZSTD_rleCompressBlock(void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) { BYTE* const op = (BYTE*)dst; U32 const cBlockHeader = lastBlock + (((U32)bt_rle)<<1) + (U32)(srcSize << 3); @@ -529,7 +670,7 @@ MEM_STATIC size_t ZSTD_minGain(size_t srcSize, ZSTD_strategy strat) { U32 const minlog = (strat>=ZSTD_btultra) ? (U32)(strat) - 1 : 6; ZSTD_STATIC_ASSERT(ZSTD_btultra == 8); - assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, strat)); + assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, (int)strat)); return (srcSize >> minlog) + 2; } @@ -565,29 +706,68 @@ ZSTD_safecopyLiterals(BYTE* op, BYTE const* ip, BYTE const* const iend, BYTE con while (ip < iend) *op++ = *ip++; } -#define ZSTD_REP_MOVE (ZSTD_REP_NUM-1) -#define STORE_REPCODE_1 STORE_REPCODE(1) -#define STORE_REPCODE_2 STORE_REPCODE(2) -#define STORE_REPCODE_3 STORE_REPCODE(3) -#define STORE_REPCODE(r) (assert((r)>=1), assert((r)<=3), (r)-1) -#define STORE_OFFSET(o) (assert((o)>0), o + ZSTD_REP_MOVE) -#define STORED_IS_OFFSET(o) ((o) > ZSTD_REP_MOVE) -#define STORED_IS_REPCODE(o) ((o) <= ZSTD_REP_MOVE) -#define STORED_OFFSET(o) (assert(STORED_IS_OFFSET(o)), (o)-ZSTD_REP_MOVE) -#define STORED_REPCODE(o) (assert(STORED_IS_REPCODE(o)), (o)+1) /* returns ID 1,2,3 */ -#define STORED_TO_OFFBASE(o) ((o)+1) -#define OFFBASE_TO_STORED(o) ((o)-1) + +#define REPCODE1_TO_OFFBASE REPCODE_TO_OFFBASE(1) +#define REPCODE2_TO_OFFBASE REPCODE_TO_OFFBASE(2) +#define REPCODE3_TO_OFFBASE REPCODE_TO_OFFBASE(3) +#define REPCODE_TO_OFFBASE(r) (assert((r)>=1), assert((r)<=ZSTD_REP_NUM), (r)) /* accepts IDs 1,2,3 */ +#define OFFSET_TO_OFFBASE(o) (assert((o)>0), o + ZSTD_REP_NUM) +#define OFFBASE_IS_OFFSET(o) ((o) > ZSTD_REP_NUM) +#define OFFBASE_IS_REPCODE(o) ( 1 <= (o) && (o) <= ZSTD_REP_NUM) +#define OFFBASE_TO_OFFSET(o) (assert(OFFBASE_IS_OFFSET(o)), (o) - ZSTD_REP_NUM) +#define OFFBASE_TO_REPCODE(o) (assert(OFFBASE_IS_REPCODE(o)), (o)) /* returns ID 1,2,3 */ + +/*! ZSTD_storeSeqOnly() : + * Store a sequence (litlen, litPtr, offBase and matchLength) into SeqStore_t. + * Literals themselves are not copied, but @litPtr is updated. + * @offBase : Users should employ macros REPCODE_TO_OFFBASE() and OFFSET_TO_OFFBASE(). + * @matchLength : must be >= MINMATCH +*/ +HINT_INLINE UNUSED_ATTR void +ZSTD_storeSeqOnly(SeqStore_t* seqStorePtr, + size_t litLength, + U32 offBase, + size_t matchLength) +{ + assert((size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart) < seqStorePtr->maxNbSeq); + + /* literal Length */ + assert(litLength <= ZSTD_BLOCKSIZE_MAX); + if (UNLIKELY(litLength>0xFFFF)) { + assert(seqStorePtr->longLengthType == ZSTD_llt_none); /* there can only be a single long length */ + seqStorePtr->longLengthType = ZSTD_llt_literalLength; + seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + } + seqStorePtr->sequences[0].litLength = (U16)litLength; + + /* match offset */ + seqStorePtr->sequences[0].offBase = offBase; + + /* match Length */ + assert(matchLength <= ZSTD_BLOCKSIZE_MAX); + assert(matchLength >= MINMATCH); + { size_t const mlBase = matchLength - MINMATCH; + if (UNLIKELY(mlBase>0xFFFF)) { + assert(seqStorePtr->longLengthType == ZSTD_llt_none); /* there can only be a single long length */ + seqStorePtr->longLengthType = ZSTD_llt_matchLength; + seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + } + seqStorePtr->sequences[0].mlBase = (U16)mlBase; + } + + seqStorePtr->sequences++; +} /*! ZSTD_storeSeq() : - * Store a sequence (litlen, litPtr, offCode and matchLength) into seqStore_t. - * @offBase_minus1 : Users should use employ macros STORE_REPCODE_X and STORE_OFFSET(). + * Store a sequence (litlen, litPtr, offBase and matchLength) into SeqStore_t. + * @offBase : Users should employ macros REPCODE_TO_OFFBASE() and OFFSET_TO_OFFBASE(). * @matchLength : must be >= MINMATCH - * Allowed to overread literals up to litLimit. + * Allowed to over-read literals up to litLimit. */ HINT_INLINE UNUSED_ATTR void -ZSTD_storeSeq(seqStore_t* seqStorePtr, +ZSTD_storeSeq(SeqStore_t* seqStorePtr, size_t litLength, const BYTE* literals, const BYTE* litLimit, - U32 offBase_minus1, + U32 offBase, size_t matchLength) { BYTE const* const litLimit_w = litLimit - WILDCOPY_OVERLENGTH; @@ -596,8 +776,8 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, static const BYTE* g_start = NULL; if (g_start==NULL) g_start = (const BYTE*)literals; /* note : index only works for compression within a single segment */ { U32 const pos = (U32)((const BYTE*)literals - g_start); - DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offCode%7u", - pos, (U32)litLength, (U32)matchLength, (U32)offBase_minus1); + DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offBase%7u", + pos, (U32)litLength, (U32)matchLength, (U32)offBase); } #endif assert((size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart) < seqStorePtr->maxNbSeq); @@ -607,9 +787,9 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, assert(literals + litLength <= litLimit); if (litEnd <= litLimit_w) { /* Common case we can use wildcopy. - * First copy 16 bytes, because literals are likely short. - */ - assert(WILDCOPY_OVERLENGTH >= 16); + * First copy 16 bytes, because literals are likely short. + */ + ZSTD_STATIC_ASSERT(WILDCOPY_OVERLENGTH >= 16); ZSTD_copy16(seqStorePtr->lit, literals); if (litLength > 16) { ZSTD_wildcopy(seqStorePtr->lit+16, literals+16, (ptrdiff_t)litLength-16, ZSTD_no_overlap); @@ -619,44 +799,22 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, } seqStorePtr->lit += litLength; - /* literal Length */ - if (litLength>0xFFFF) { - assert(seqStorePtr->longLengthType == ZSTD_llt_none); /* there can only be a single long length */ - seqStorePtr->longLengthType = ZSTD_llt_literalLength; - seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); - } - seqStorePtr->sequences[0].litLength = (U16)litLength; - - /* match offset */ - seqStorePtr->sequences[0].offBase = STORED_TO_OFFBASE(offBase_minus1); - - /* match Length */ - assert(matchLength >= MINMATCH); - { size_t const mlBase = matchLength - MINMATCH; - if (mlBase>0xFFFF) { - assert(seqStorePtr->longLengthType == ZSTD_llt_none); /* there can only be a single long length */ - seqStorePtr->longLengthType = ZSTD_llt_matchLength; - seqStorePtr->longLengthPos = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); - } - seqStorePtr->sequences[0].mlBase = (U16)mlBase; - } - - seqStorePtr->sequences++; + ZSTD_storeSeqOnly(seqStorePtr, litLength, offBase, matchLength); } /* ZSTD_updateRep() : * updates in-place @rep (array of repeat offsets) - * @offBase_minus1 : sum-type, with same numeric representation as ZSTD_storeSeq() + * @offBase : sum-type, using numeric representation of ZSTD_storeSeq() */ MEM_STATIC void -ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) +ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) { - if (STORED_IS_OFFSET(offBase_minus1)) { /* full offset */ + if (OFFBASE_IS_OFFSET(offBase)) { /* full offset */ rep[2] = rep[1]; rep[1] = rep[0]; - rep[0] = STORED_OFFSET(offBase_minus1); + rep[0] = OFFBASE_TO_OFFSET(offBase); } else { /* repcode */ - U32 const repCode = STORED_REPCODE(offBase_minus1) - 1 + ll0; + U32 const repCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; if (repCode > 0) { /* note : if repCode==0, no change */ U32 const currentOffset = (repCode==ZSTD_REP_NUM) ? (rep[0] - 1) : rep[repCode]; rep[2] = (repCode >= 2) ? rep[1] : rep[2]; @@ -670,14 +828,14 @@ ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) typedef struct repcodes_s { U32 rep[3]; -} repcodes_t; +} Repcodes_t; -MEM_STATIC repcodes_t -ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) +MEM_STATIC Repcodes_t +ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) { - repcodes_t newReps; + Repcodes_t newReps; ZSTD_memcpy(&newReps, rep, sizeof(newReps)); - ZSTD_updateRep(newReps.rep, offBase_minus1, ll0); + ZSTD_updateRep(newReps.rep, offBase, ll0); return newReps; } @@ -685,59 +843,6 @@ ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0 /*-************************************* * Match length counter ***************************************/ -static unsigned ZSTD_NbCommonBytes (size_t val) -{ - if (MEM_isLittleEndian()) { - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return (__builtin_ctzll((U64)val) >> 3); -# else - static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, - 0, 3, 1, 3, 1, 4, 2, 7, - 0, 2, 3, 6, 1, 5, 3, 5, - 1, 3, 4, 4, 2, 5, 6, 7, - 7, 0, 1, 2, 3, 3, 4, 6, - 2, 6, 5, 5, 3, 4, 5, 6, - 7, 1, 2, 4, 6, 4, 4, 5, - 7, 2, 6, 5, 7, 6, 7, 7 }; - return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return (__builtin_ctz((U32)val) >> 3); -# else - static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, - 3, 2, 2, 1, 3, 2, 0, 1, - 3, 3, 1, 2, 2, 2, 2, 0, - 3, 1, 2, 0, 1, 0, 1, 1 }; - return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; -# endif - } - } else { /* Big Endian CPU */ - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return (__builtin_clzll(val) >> 3); -# else - unsigned r; - const unsigned n32 = sizeof(size_t)*4; /* calculate this way due to compiler complaining in 32-bits mode */ - if (!(val>>n32)) { r=4; } else { r=0; val>>=n32; } - if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } - r += (!val); - return r; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return (__builtin_clz((U32)val) >> 3); -# else - unsigned r; - if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } - r += (!val); - return r; -# endif - } } -} - - MEM_STATIC size_t ZSTD_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* const pInLimit) { const BYTE* const pStart = pIn; @@ -771,8 +876,8 @@ ZSTD_count_2segments(const BYTE* ip, const BYTE* match, size_t const matchLength = ZSTD_count(ip, match, vEnd); if (match + matchLength != mEnd) return matchLength; DEBUGLOG(7, "ZSTD_count_2segments: found a 2-parts match (current length==%zu)", matchLength); - DEBUGLOG(7, "distance from match beginning to end dictionary = %zi", mEnd - match); - DEBUGLOG(7, "distance from current pos to end buffer = %zi", iEnd - ip); + DEBUGLOG(7, "distance from match beginning to end dictionary = %i", (int)(mEnd - match)); + DEBUGLOG(7, "distance from current pos to end buffer = %i", (int)(iEnd - ip)); DEBUGLOG(7, "next byte : ip==%02X, istart==%02X", ip[matchLength], *iStart); DEBUGLOG(7, "final match length = %zu", matchLength + ZSTD_count(ip+matchLength, iStart, iEnd)); return matchLength + ZSTD_count(ip+matchLength, iStart, iEnd); @@ -783,32 +888,43 @@ ZSTD_count_2segments(const BYTE* ip, const BYTE* match, * Hashes ***************************************/ static const U32 prime3bytes = 506832829U; -static U32 ZSTD_hash3(U32 u, U32 h) { return ((u << (32-24)) * prime3bytes) >> (32-h) ; } -MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h); } /* only in zstd_opt.h */ +static U32 ZSTD_hash3(U32 u, U32 h, U32 s) { assert(h <= 32); return (((u << (32-24)) * prime3bytes) ^ s) >> (32-h) ; } +MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h, 0); } /* only in zstd_opt.h */ +MEM_STATIC size_t ZSTD_hash3PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash3(MEM_readLE32(ptr), h, s); } static const U32 prime4bytes = 2654435761U; -static U32 ZSTD_hash4(U32 u, U32 h) { return (u * prime4bytes) >> (32-h) ; } -static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_read32(ptr), h); } +static U32 ZSTD_hash4(U32 u, U32 h, U32 s) { assert(h <= 32); return ((u * prime4bytes) ^ s) >> (32-h) ; } +static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_readLE32(ptr), h, 0); } +static size_t ZSTD_hash4PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash4(MEM_readLE32(ptr), h, s); } static const U64 prime5bytes = 889523592379ULL; -static size_t ZSTD_hash5(U64 u, U32 h) { return (size_t)(((u << (64-40)) * prime5bytes) >> (64-h)) ; } -static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h); } +static size_t ZSTD_hash5(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-40)) * prime5bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash5PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash5(MEM_readLE64(p), h, s); } static const U64 prime6bytes = 227718039650203ULL; -static size_t ZSTD_hash6(U64 u, U32 h) { return (size_t)(((u << (64-48)) * prime6bytes) >> (64-h)) ; } -static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h); } +static size_t ZSTD_hash6(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-48)) * prime6bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash6PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash6(MEM_readLE64(p), h, s); } static const U64 prime7bytes = 58295818150454627ULL; -static size_t ZSTD_hash7(U64 u, U32 h) { return (size_t)(((u << (64-56)) * prime7bytes) >> (64-h)) ; } -static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h); } +static size_t ZSTD_hash7(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-56)) * prime7bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash7PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash7(MEM_readLE64(p), h, s); } static const U64 prime8bytes = 0xCF1BBCDCB7A56463ULL; -static size_t ZSTD_hash8(U64 u, U32 h) { return (size_t)(((u) * prime8bytes) >> (64-h)) ; } -static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h); } +static size_t ZSTD_hash8(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u) * prime8bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash8PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash8(MEM_readLE64(p), h, s); } + MEM_STATIC FORCE_INLINE_ATTR size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) { + /* Although some of these hashes do support hBits up to 64, some do not. + * To be on the safe side, always avoid hBits > 32. */ + assert(hBits <= 32); + switch(mls) { default: @@ -820,6 +936,24 @@ size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) } } +MEM_STATIC FORCE_INLINE_ATTR +size_t ZSTD_hashPtrSalted(const void* p, U32 hBits, U32 mls, const U64 hashSalt) { + /* Although some of these hashes do support hBits up to 64, some do not. + * To be on the safe side, always avoid hBits > 32. */ + assert(hBits <= 32); + + switch(mls) + { + default: + case 4: return ZSTD_hash4PtrS(p, hBits, (U32)hashSalt); + case 5: return ZSTD_hash5PtrS(p, hBits, hashSalt); + case 6: return ZSTD_hash6PtrS(p, hBits, hashSalt); + case 7: return ZSTD_hash7PtrS(p, hBits, hashSalt); + case 8: return ZSTD_hash8PtrS(p, hBits, hashSalt); + } +} + + /* ZSTD_ipow() : * Return base^exponent. */ @@ -881,11 +1015,12 @@ MEM_STATIC U64 ZSTD_rollingHash_rotate(U64 hash, BYTE toRemove, BYTE toAdd, U64 /*-************************************* * Round buffer management ***************************************/ -#if (ZSTD_WINDOWLOG_MAX_64 > 31) -# error "ZSTD_WINDOWLOG_MAX is too large : would overflow ZSTD_CURRENT_MAX" -#endif -/* Max current allowed */ -#define ZSTD_CURRENT_MAX ((3U << 29) + (1U << ZSTD_WINDOWLOG_MAX)) +/* Max @current value allowed: + * In 32-bit mode: we want to avoid crossing the 2 GB limit, + * reducing risks of side effects in case of signed operations on indexes. + * In 64-bit mode: we want to ensure that adding the maximum job size (512 MB) + * doesn't overflow U32 index capacity (4 GB) */ +#define ZSTD_CURRENT_MAX (MEM_64bits() ? 3500U MB : 2000U MB) /* Maximum chunk size before overflow correction needs to be called again */ #define ZSTD_CHUNKSIZE_MAX \ ( ((U32)-1) /* Maximum ending current index */ \ @@ -925,7 +1060,7 @@ MEM_STATIC U32 ZSTD_window_hasExtDict(ZSTD_window_t const window) * Inspects the provided matchState and figures out what dictMode should be * passed to the compressor. */ -MEM_STATIC ZSTD_dictMode_e ZSTD_matchState_dictMode(const ZSTD_matchState_t *ms) +MEM_STATIC ZSTD_dictMode_e ZSTD_matchState_dictMode(const ZSTD_MatchState_t *ms) { return ZSTD_window_hasExtDict(ms->window) ? ZSTD_extDict : @@ -1011,7 +1146,9 @@ MEM_STATIC U32 ZSTD_window_needOverflowCorrection(ZSTD_window_t const window, * The least significant cycleLog bits of the indices must remain the same, * which may be 0. Every index up to maxDist in the past must be valid. */ -MEM_STATIC U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, U32 maxDist, void const* src) { /* preemptive overflow correction: @@ -1112,7 +1249,7 @@ ZSTD_window_enforceMaxDist(ZSTD_window_t* window, const void* blockEnd, U32 maxDist, U32* loadedDictEndPtr, - const ZSTD_matchState_t** dictMatchStatePtr) + const ZSTD_MatchState_t** dictMatchStatePtr) { U32 const blockEndIdx = (U32)((BYTE const*)blockEnd - window->base); U32 const loadedDictEnd = (loadedDictEndPtr != NULL) ? *loadedDictEndPtr : 0; @@ -1157,7 +1294,7 @@ ZSTD_checkDictValidity(const ZSTD_window_t* window, const void* blockEnd, U32 maxDist, U32* loadedDictEndPtr, - const ZSTD_matchState_t** dictMatchStatePtr) + const ZSTD_MatchState_t** dictMatchStatePtr) { assert(loadedDictEndPtr != NULL); assert(dictMatchStatePtr != NULL); @@ -1167,10 +1304,15 @@ ZSTD_checkDictValidity(const ZSTD_window_t* window, (unsigned)blockEndIdx, (unsigned)maxDist, (unsigned)loadedDictEnd); assert(blockEndIdx >= loadedDictEnd); - if (blockEndIdx > loadedDictEnd + maxDist) { + if (blockEndIdx > loadedDictEnd + maxDist || loadedDictEnd != window->dictLimit) { /* On reaching window size, dictionaries are invalidated. * For simplification, if window size is reached anywhere within next block, * the dictionary is invalidated for the full block. + * + * We also have to invalidate the dictionary if ZSTD_window_update() has detected + * non-contiguous segments, which means that loadedDictEnd != window->dictLimit. + * loadedDictEnd may be 0, if forceWindow is true, but in that case we never use + * dictMatchState, so setting it to NULL is not a problem. */ DEBUGLOG(6, "invalidating dictionary for current block (distance > windowSize)"); *loadedDictEndPtr = 0; @@ -1199,9 +1341,11 @@ MEM_STATIC void ZSTD_window_init(ZSTD_window_t* window) { * forget about the extDict. Handles overlap of the prefix and extDict. * Returns non-zero if the segment is contiguous. */ -MEM_STATIC U32 ZSTD_window_update(ZSTD_window_t* window, - void const* src, size_t srcSize, - int forceNonContiguous) +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_window_update(ZSTD_window_t* window, + const void* src, size_t srcSize, + int forceNonContiguous) { BYTE const* const ip = (BYTE const*)src; U32 contiguous = 1; @@ -1228,8 +1372,9 @@ MEM_STATIC U32 ZSTD_window_update(ZSTD_window_t* window, /* if input and dictionary overlap : reduce dictionary (area presumed modified by input) */ if ( (ip+srcSize > window->dictBase + window->lowLimit) & (ip < window->dictBase + window->dictLimit)) { - ptrdiff_t const highInputIdx = (ip + srcSize) - window->dictBase; - U32 const lowLimitMax = (highInputIdx > (ptrdiff_t)window->dictLimit) ? window->dictLimit : (U32)highInputIdx; + size_t const highInputIdx = (size_t)((ip + srcSize) - window->dictBase); + U32 const lowLimitMax = (highInputIdx > (size_t)window->dictLimit) ? window->dictLimit : (U32)highInputIdx; + assert(highInputIdx < UINT_MAX); window->lowLimit = lowLimitMax; DEBUGLOG(5, "Overlapping extDict and input : new lowLimit = %u", window->lowLimit); } @@ -1239,7 +1384,7 @@ MEM_STATIC U32 ZSTD_window_update(ZSTD_window_t* window, /* * Returns the lowest allowed match index. It may either be in the ext-dict or the prefix. */ -MEM_STATIC U32 ZSTD_getLowestMatchIndex(const ZSTD_matchState_t* ms, U32 curr, unsigned windowLog) +MEM_STATIC U32 ZSTD_getLowestMatchIndex(const ZSTD_MatchState_t* ms, U32 curr, unsigned windowLog) { U32 const maxDistance = 1U << windowLog; U32 const lowestValid = ms->window.lowLimit; @@ -1256,7 +1401,7 @@ MEM_STATIC U32 ZSTD_getLowestMatchIndex(const ZSTD_matchState_t* ms, U32 curr, u /* * Returns the lowest allowed match index in the prefix. */ -MEM_STATIC U32 ZSTD_getLowestPrefixIndex(const ZSTD_matchState_t* ms, U32 curr, unsigned windowLog) +MEM_STATIC U32 ZSTD_getLowestPrefixIndex(const ZSTD_MatchState_t* ms, U32 curr, unsigned windowLog) { U32 const maxDistance = 1U << windowLog; U32 const lowestValid = ms->window.dictLimit; @@ -1269,6 +1414,13 @@ MEM_STATIC U32 ZSTD_getLowestPrefixIndex(const ZSTD_matchState_t* ms, U32 curr, return matchLowest; } +/* index_safety_check: + * intentional underflow : ensure repIndex isn't overlapping dict + prefix + * @return 1 if values are not overlapping, + * 0 otherwise */ +MEM_STATIC int ZSTD_index_overlap_check(const U32 prefixLowestIndex, const U32 repIndex) { + return ((U32)((prefixLowestIndex-1) - repIndex) >= 3); +} /* debug functions */ @@ -1302,7 +1454,42 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max) #endif +/* Short Cache */ + +/* Normally, zstd matchfinders follow this flow: + * 1. Compute hash at ip + * 2. Load index from hashTable[hash] + * 3. Check if *ip == *(base + index) + * In dictionary compression, loading *(base + index) is often an L2 or even L3 miss. + * + * Short cache is an optimization which allows us to avoid step 3 most of the time + * when the data doesn't actually match. With short cache, the flow becomes: + * 1. Compute (hash, currentTag) at ip. currentTag is an 8-bit independent hash at ip. + * 2. Load (index, matchTag) from hashTable[hash]. See ZSTD_writeTaggedIndex to understand how this works. + * 3. Only if currentTag == matchTag, check *ip == *(base + index). Otherwise, continue. + * + * Currently, short cache is only implemented in CDict hashtables. Thus, its use is limited to + * dictMatchState matchfinders. + */ +#define ZSTD_SHORT_CACHE_TAG_BITS 8 +#define ZSTD_SHORT_CACHE_TAG_MASK ((1u << ZSTD_SHORT_CACHE_TAG_BITS) - 1) + +/* Helper function for ZSTD_fillHashTable and ZSTD_fillDoubleHashTable. + * Unpacks hashAndTag into (hash, tag), then packs (index, tag) into hashTable[hash]. */ +MEM_STATIC void ZSTD_writeTaggedIndex(U32* const hashTable, size_t hashAndTag, U32 index) { + size_t const hash = hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; + U32 const tag = (U32)(hashAndTag & ZSTD_SHORT_CACHE_TAG_MASK); + assert(index >> (32 - ZSTD_SHORT_CACHE_TAG_BITS) == 0); + hashTable[hash] = (index << ZSTD_SHORT_CACHE_TAG_BITS) | tag; +} +/* Helper function for short cache matchfinders. + * Unpacks tag1 and tag2 from lower bits of packedTag1 and packedTag2, then checks if the tags match. */ +MEM_STATIC int ZSTD_comparePackedTags(size_t packedTag1, size_t packedTag2) { + U32 const tag1 = packedTag1 & ZSTD_SHORT_CACHE_TAG_MASK; + U32 const tag2 = packedTag2 & ZSTD_SHORT_CACHE_TAG_MASK; + return tag1 == tag2; +} /* =============================================================== * Shared internal declarations @@ -1319,6 +1506,25 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs); +typedef struct { + U32 idx; /* Index in array of ZSTD_Sequence */ + U32 posInSequence; /* Position within sequence at idx */ + size_t posInSrc; /* Number of bytes given by sequences provided so far */ +} ZSTD_SequencePosition; + +/* for benchmark */ +size_t ZSTD_convertBlockSequences(ZSTD_CCtx* cctx, + const ZSTD_Sequence* const inSeqs, size_t nbSequences, + int const repcodeResolution); + +typedef struct { + size_t nbSequences; + size_t blockSize; + size_t litSize; +} BlockSummary; + +BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs); + /* ============================================================== * Private declarations * These prototypes shall only be called from within lib/compress @@ -1330,7 +1536,7 @@ void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs); * Note: srcSizeHint == 0 means 0! */ ZSTD_compressionParameters ZSTD_getCParamsFromCCtxParams( - const ZSTD_CCtx_params* CCtxParams, U64 srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode); + const ZSTD_CCtx_params* CCtxParams, U64 srcSizeHint, size_t dictSize, ZSTD_CParamMode_e mode); /*! ZSTD_initCStream_internal() : * Private use only. Init streaming operation. @@ -1342,7 +1548,7 @@ size_t ZSTD_initCStream_internal(ZSTD_CStream* zcs, const ZSTD_CDict* cdict, const ZSTD_CCtx_params* params, unsigned long long pledgedSrcSize); -void ZSTD_resetSeqStore(seqStore_t* ssPtr); +void ZSTD_resetSeqStore(SeqStore_t* ssPtr); /*! ZSTD_getCParamsFromCDict() : * as the name implies */ @@ -1381,11 +1587,10 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity); * This cannot be used when long range matching is enabled. * Zstd will use these sequences, and pass the literals to a secondary block * compressor. - * @return : An error code on failure. * NOTE: seqs are not verified! Invalid sequences can cause out-of-bounds memory * access and data corruption. */ -size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); +void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); /* ZSTD_cycleLog() : * condition for correct operation : hashLog > 1 */ @@ -1396,4 +1601,28 @@ U32 ZSTD_cycleLog(U32 hashLog, ZSTD_strategy strat); */ void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize); +/* Returns 1 if an external sequence producer is registered, otherwise returns 0. */ +MEM_STATIC int ZSTD_hasExtSeqProd(const ZSTD_CCtx_params* params) { + return params->extSeqProdFunc != NULL; +} + +/* =============================================================== + * Deprecated definitions that are still used internally to avoid + * deprecation warnings. These functions are exactly equivalent to + * their public variants, but avoid the deprecation warnings. + * =============================================================== */ + +size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); + +size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + +size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + +size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); + + #endif /* ZSTD_COMPRESS_H */ diff --git a/lib/zstd/compress/zstd_compress_literals.c b/lib/zstd/compress/zstd_compress_literals.c index 52b0a8059aba..ec39b4299b6f 100644 --- a/lib/zstd/compress/zstd_compress_literals.c +++ b/lib/zstd/compress/zstd_compress_literals.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -13,11 +14,36 @@ ***************************************/ #include "zstd_compress_literals.h" + +/* ************************************************************** +* Debug Traces +****************************************************************/ +#if DEBUGLEVEL >= 2 + +static size_t showHexa(const void* src, size_t srcSize) +{ + const BYTE* const ip = (const BYTE*)src; + size_t u; + for (u=0; u31) + (srcSize>4095); + DEBUGLOG(5, "ZSTD_noCompressLiterals: srcSize=%zu, dstCapacity=%zu", srcSize, dstCapacity); + RETURN_ERROR_IF(srcSize + flSize > dstCapacity, dstSize_tooSmall, ""); switch(flSize) @@ -36,16 +62,30 @@ size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, } ZSTD_memcpy(ostart + flSize, src, srcSize); - DEBUGLOG(5, "Raw literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); + DEBUGLOG(5, "Raw (uncompressed) literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); return srcSize + flSize; } +static int allBytesIdentical(const void* src, size_t srcSize) +{ + assert(srcSize >= 1); + assert(src != NULL); + { const BYTE b = ((const BYTE*)src)[0]; + size_t p; + for (p=1; p31) + (srcSize>4095); - (void)dstCapacity; /* dstCapacity already guaranteed to be >=4, hence large enough */ + assert(dstCapacity >= 4); (void)dstCapacity; + assert(allBytesIdentical(src, srcSize)); switch(flSize) { @@ -63,28 +103,51 @@ size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* } ostart[flSize] = *(const BYTE*)src; - DEBUGLOG(5, "RLE literals: %u -> %u", (U32)srcSize, (U32)flSize + 1); + DEBUGLOG(5, "RLE : Repeated Literal (%02X: %u times) -> %u bytes encoded", ((const BYTE*)src)[0], (U32)srcSize, (U32)flSize + 1); return flSize+1; } -size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_strategy strategy, int disableLiteralCompression, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - void* entropyWorkspace, size_t entropyWorkspaceSize, - const int bmi2, - unsigned suspectUncompressible) +/* ZSTD_minLiteralsToCompress() : + * returns minimal amount of literals + * for literal compression to even be attempted. + * Minimum is made tighter as compression strategy increases. + */ +static size_t +ZSTD_minLiteralsToCompress(ZSTD_strategy strategy, HUF_repeat huf_repeat) +{ + assert((int)strategy >= 0); + assert((int)strategy <= 9); + /* btultra2 : min 8 bytes; + * then 2x larger for each successive compression strategy + * max threshold 64 bytes */ + { int const shift = MIN(9-(int)strategy, 3); + size_t const mintc = (huf_repeat == HUF_repeat_valid) ? 6 : (size_t)8 << shift; + DEBUGLOG(7, "minLiteralsToCompress = %zu", mintc); + return mintc; + } +} + +size_t ZSTD_compressLiterals ( + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + void* entropyWorkspace, size_t entropyWorkspaceSize, + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_strategy strategy, + int disableLiteralCompression, + int suspectUncompressible, + int bmi2) { - size_t const minGain = ZSTD_minGain(srcSize, strategy); size_t const lhSize = 3 + (srcSize >= 1 KB) + (srcSize >= 16 KB); BYTE* const ostart = (BYTE*)dst; U32 singleStream = srcSize < 256; - symbolEncodingType_e hType = set_compressed; + SymbolEncodingType_e hType = set_compressed; size_t cLitSize; - DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i srcSize=%u)", - disableLiteralCompression, (U32)srcSize); + DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i, srcSize=%u, dstCapacity=%zu)", + disableLiteralCompression, (U32)srcSize, dstCapacity); + + DEBUGLOG(6, "Completed literals listing (%zu bytes)", showHexa(src, srcSize)); /* Prepare nextEntropy assuming reusing the existing table */ ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); @@ -92,40 +155,51 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, if (disableLiteralCompression) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - /* small ? don't even attempt compression (speed opt) */ -# define COMPRESS_LITERALS_SIZE_MIN 63 - { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; - if (srcSize <= minLitSize) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - } + /* if too small, don't even attempt compression (speed opt) */ + if (srcSize < ZSTD_minLiteralsToCompress(strategy, prevHuf->repeatMode)) + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); RETURN_ERROR_IF(dstCapacity < lhSize+1, dstSize_tooSmall, "not enough space for compression"); { HUF_repeat repeat = prevHuf->repeatMode; - int const preferRepeat = strategy < ZSTD_lazy ? srcSize <= 1024 : 0; + int const flags = 0 + | (bmi2 ? HUF_flags_bmi2 : 0) + | (strategy < ZSTD_lazy && srcSize <= 1024 ? HUF_flags_preferRepeat : 0) + | (strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD ? HUF_flags_optimalDepth : 0) + | (suspectUncompressible ? HUF_flags_suspectUncompressible : 0); + + typedef size_t (*huf_compress_f)(void*, size_t, const void*, size_t, unsigned, unsigned, void*, size_t, HUF_CElt*, HUF_repeat*, int); + huf_compress_f huf_compress; if (repeat == HUF_repeat_valid && lhSize == 3) singleStream = 1; - cLitSize = singleStream ? - HUF_compress1X_repeat( - ostart+lhSize, dstCapacity-lhSize, src, srcSize, - HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, - (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible) : - HUF_compress4X_repeat( - ostart+lhSize, dstCapacity-lhSize, src, srcSize, - HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, - (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible); + huf_compress = singleStream ? HUF_compress1X_repeat : HUF_compress4X_repeat; + cLitSize = huf_compress(ostart+lhSize, dstCapacity-lhSize, + src, srcSize, + HUF_SYMBOLVALUE_MAX, LitHufLog, + entropyWorkspace, entropyWorkspaceSize, + (HUF_CElt*)nextHuf->CTable, + &repeat, flags); + DEBUGLOG(5, "%zu literals compressed into %zu bytes (before header)", srcSize, cLitSize); if (repeat != HUF_repeat_none) { /* reused the existing table */ - DEBUGLOG(5, "Reusing previous huffman table"); + DEBUGLOG(5, "reusing statistics from previous huffman block"); hType = set_repeat; } } - if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - } + { size_t const minGain = ZSTD_minGain(srcSize, strategy); + if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + } } if (cLitSize==1) { - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); - } + /* A return value of 1 signals that the alphabet consists of a single symbol. + * However, in some rare circumstances, it could be the compressed size (a single byte). + * For that outcome to have a chance to happen, it's necessary that `srcSize < 8`. + * (it's also necessary to not generate statistics). + * Therefore, in such a case, actively check that all bytes are identical. */ + if ((srcSize >= 8) || allBytesIdentical(src, srcSize)) { + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); + } } if (hType == set_compressed) { /* using a newly constructed table */ @@ -136,16 +210,19 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, switch(lhSize) { case 3: /* 2 - 2 - 10 - 10 */ - { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); + if (!singleStream) assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); + { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); MEM_writeLE24(ostart, lhc); break; } case 4: /* 2 - 2 - 14 - 14 */ + assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); { U32 const lhc = hType + (2 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<18); MEM_writeLE32(ostart, lhc); break; } case 5: /* 2 - 2 - 18 - 18 */ + assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); { U32 const lhc = hType + (3 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<22); MEM_writeLE32(ostart, lhc); ostart[4] = (BYTE)(cLitSize >> 10); diff --git a/lib/zstd/compress/zstd_compress_literals.h b/lib/zstd/compress/zstd_compress_literals.h index 9775fb97cb70..a2a85d6b69e5 100644 --- a/lib/zstd/compress/zstd_compress_literals.h +++ b/lib/zstd/compress/zstd_compress_literals.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -16,16 +17,24 @@ size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, size_t srcSize); +/* ZSTD_compressRleLiteralsBlock() : + * Conditions : + * - All bytes in @src are identical + * - dstCapacity >= 4 */ size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize); -/* If suspectUncompressible then some sampling checks will be run to potentially skip huffman coding */ -size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_strategy strategy, int disableLiteralCompression, - void* dst, size_t dstCapacity, +/* ZSTD_compressLiterals(): + * @entropyWorkspace: must be aligned on 4-bytes boundaries + * @entropyWorkspaceSize : must be >= HUF_WORKSPACE_SIZE + * @suspectUncompressible: sampling checks, to potentially skip huffman coding + */ +size_t ZSTD_compressLiterals (void* dst, size_t dstCapacity, const void* src, size_t srcSize, void* entropyWorkspace, size_t entropyWorkspaceSize, - const int bmi2, - unsigned suspectUncompressible); + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_strategy strategy, int disableLiteralCompression, + int suspectUncompressible, + int bmi2); #endif /* ZSTD_COMPRESS_LITERALS_H */ diff --git a/lib/zstd/compress/zstd_compress_sequences.c b/lib/zstd/compress/zstd_compress_sequences.c index 21ddc1b37acf..256980c9d85a 100644 --- a/lib/zstd/compress/zstd_compress_sequences.c +++ b/lib/zstd/compress/zstd_compress_sequences.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -58,7 +59,7 @@ static unsigned ZSTD_useLowProbCount(size_t const nbSeq) { /* Heuristic: This should cover most blocks <= 16K and * start to fade out after 16K to about 32K depending on - * comprssibility. + * compressibility. */ return nbSeq >= 2048; } @@ -153,20 +154,20 @@ size_t ZSTD_crossEntropyCost(short const* norm, unsigned accuracyLog, return cost >> 8; } -symbolEncodingType_e +SymbolEncodingType_e ZSTD_selectEncodingType( FSE_repeat* repeatMode, unsigned const* count, unsigned const max, size_t const mostFrequent, size_t nbSeq, unsigned const FSELog, FSE_CTable const* prevCTable, short const* defaultNorm, U32 defaultNormLog, - ZSTD_defaultPolicy_e const isDefaultAllowed, + ZSTD_DefaultPolicy_e const isDefaultAllowed, ZSTD_strategy const strategy) { ZSTD_STATIC_ASSERT(ZSTD_defaultDisallowed == 0 && ZSTD_defaultAllowed != 0); if (mostFrequent == nbSeq) { *repeatMode = FSE_repeat_none; if (isDefaultAllowed && nbSeq <= 2) { - /* Prefer set_basic over set_rle when there are 2 or less symbols, + /* Prefer set_basic over set_rle when there are 2 or fewer symbols, * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol. * If basic encoding isn't possible, always choose RLE. */ @@ -241,7 +242,7 @@ typedef struct { size_t ZSTD_buildCTable(void* dst, size_t dstCapacity, - FSE_CTable* nextCTable, U32 FSELog, symbolEncodingType_e type, + FSE_CTable* nextCTable, U32 FSELog, SymbolEncodingType_e type, unsigned* count, U32 max, const BYTE* codeTable, size_t nbSeq, const S16* defaultNorm, U32 defaultNormLog, U32 defaultMax, @@ -293,7 +294,7 @@ ZSTD_encodeSequences_body( FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable, FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable, FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable, - seqDef const* sequences, size_t nbSeq, int longOffsets) + SeqDef const* sequences, size_t nbSeq, int longOffsets) { BIT_CStream_t blockStream; FSE_CState_t stateMatchLength; @@ -387,7 +388,7 @@ ZSTD_encodeSequences_default( FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable, FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable, FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable, - seqDef const* sequences, size_t nbSeq, int longOffsets) + SeqDef const* sequences, size_t nbSeq, int longOffsets) { return ZSTD_encodeSequences_body(dst, dstCapacity, CTable_MatchLength, mlCodeTable, @@ -405,7 +406,7 @@ ZSTD_encodeSequences_bmi2( FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable, FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable, FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable, - seqDef const* sequences, size_t nbSeq, int longOffsets) + SeqDef const* sequences, size_t nbSeq, int longOffsets) { return ZSTD_encodeSequences_body(dst, dstCapacity, CTable_MatchLength, mlCodeTable, @@ -421,7 +422,7 @@ size_t ZSTD_encodeSequences( FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable, FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable, FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable, - seqDef const* sequences, size_t nbSeq, int longOffsets, int bmi2) + SeqDef const* sequences, size_t nbSeq, int longOffsets, int bmi2) { DEBUGLOG(5, "ZSTD_encodeSequences: dstCapacity = %u", (unsigned)dstCapacity); #if DYNAMIC_BMI2 diff --git a/lib/zstd/compress/zstd_compress_sequences.h b/lib/zstd/compress/zstd_compress_sequences.h index 7991364c2f71..14fdccb6547f 100644 --- a/lib/zstd/compress/zstd_compress_sequences.h +++ b/lib/zstd/compress/zstd_compress_sequences.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,26 +12,27 @@ #ifndef ZSTD_COMPRESS_SEQUENCES_H #define ZSTD_COMPRESS_SEQUENCES_H +#include "zstd_compress_internal.h" /* SeqDef */ #include "../common/fse.h" /* FSE_repeat, FSE_CTable */ -#include "../common/zstd_internal.h" /* symbolEncodingType_e, ZSTD_strategy */ +#include "../common/zstd_internal.h" /* SymbolEncodingType_e, ZSTD_strategy */ typedef enum { ZSTD_defaultDisallowed = 0, ZSTD_defaultAllowed = 1 -} ZSTD_defaultPolicy_e; +} ZSTD_DefaultPolicy_e; -symbolEncodingType_e +SymbolEncodingType_e ZSTD_selectEncodingType( FSE_repeat* repeatMode, unsigned const* count, unsigned const max, size_t const mostFrequent, size_t nbSeq, unsigned const FSELog, FSE_CTable const* prevCTable, short const* defaultNorm, U32 defaultNormLog, - ZSTD_defaultPolicy_e const isDefaultAllowed, + ZSTD_DefaultPolicy_e const isDefaultAllowed, ZSTD_strategy const strategy); size_t ZSTD_buildCTable(void* dst, size_t dstCapacity, - FSE_CTable* nextCTable, U32 FSELog, symbolEncodingType_e type, + FSE_CTable* nextCTable, U32 FSELog, SymbolEncodingType_e type, unsigned* count, U32 max, const BYTE* codeTable, size_t nbSeq, const S16* defaultNorm, U32 defaultNormLog, U32 defaultMax, @@ -42,7 +44,7 @@ size_t ZSTD_encodeSequences( FSE_CTable const* CTable_MatchLength, BYTE const* mlCodeTable, FSE_CTable const* CTable_OffsetBits, BYTE const* ofCodeTable, FSE_CTable const* CTable_LitLength, BYTE const* llCodeTable, - seqDef const* sequences, size_t nbSeq, int longOffsets, int bmi2); + SeqDef const* sequences, size_t nbSeq, int longOffsets, int bmi2); size_t ZSTD_fseBitCost( FSE_CTable const* ctable, diff --git a/lib/zstd/compress/zstd_compress_superblock.c b/lib/zstd/compress/zstd_compress_superblock.c index 17d836cc84e8..dc12d64e935c 100644 --- a/lib/zstd/compress/zstd_compress_superblock.c +++ b/lib/zstd/compress/zstd_compress_superblock.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -36,13 +37,14 @@ * If it is set_compressed, first sub-block's literals section will be Treeless_Literals_Block * and the following sub-blocks' literals sections will be Treeless_Literals_Block. * @return : compressed size of literals section of a sub-block - * Or 0 if it unable to compress. + * Or 0 if unable to compress. * Or error code */ -static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, - const ZSTD_hufCTablesMetadata_t* hufMetadata, - const BYTE* literals, size_t litSize, - void* dst, size_t dstSize, - const int bmi2, int writeEntropy, int* entropyWritten) +static size_t +ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + const ZSTD_hufCTablesMetadata_t* hufMetadata, + const BYTE* literals, size_t litSize, + void* dst, size_t dstSize, + const int bmi2, int writeEntropy, int* entropyWritten) { size_t const header = writeEntropy ? 200 : 0; size_t const lhSize = 3 + (litSize >= (1 KB - header)) + (litSize >= (16 KB - header)); @@ -50,11 +52,9 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, BYTE* const oend = ostart + dstSize; BYTE* op = ostart + lhSize; U32 const singleStream = lhSize == 3; - symbolEncodingType_e hType = writeEntropy ? hufMetadata->hType : set_repeat; + SymbolEncodingType_e hType = writeEntropy ? hufMetadata->hType : set_repeat; size_t cLitSize = 0; - (void)bmi2; /* TODO bmi2... */ - DEBUGLOG(5, "ZSTD_compressSubBlock_literal (litSize=%zu, lhSize=%zu, writeEntropy=%d)", litSize, lhSize, writeEntropy); *entropyWritten = 0; @@ -76,9 +76,9 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, DEBUGLOG(5, "ZSTD_compressSubBlock_literal (hSize=%zu)", hufMetadata->hufDesSize); } - /* TODO bmi2 */ - { const size_t cSize = singleStream ? HUF_compress1X_usingCTable(op, oend-op, literals, litSize, hufTable) - : HUF_compress4X_usingCTable(op, oend-op, literals, litSize, hufTable); + { int const flags = bmi2 ? HUF_flags_bmi2 : 0; + const size_t cSize = singleStream ? HUF_compress1X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags) + : HUF_compress4X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags); op += cSize; cLitSize += cSize; if (cSize == 0 || ERR_isError(cSize)) { @@ -103,7 +103,7 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, switch(lhSize) { case 3: /* 2 - 2 - 10 - 10 */ - { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); + { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); MEM_writeLE24(ostart, lhc); break; } @@ -123,26 +123,30 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, } *entropyWritten = 1; DEBUGLOG(5, "Compressed literals: %u -> %u", (U32)litSize, (U32)(op-ostart)); - return op-ostart; + return (size_t)(op-ostart); } -static size_t ZSTD_seqDecompressedSize(seqStore_t const* seqStore, const seqDef* sequences, size_t nbSeq, size_t litSize, int lastSequence) { - const seqDef* const sstart = sequences; - const seqDef* const send = sequences + nbSeq; - const seqDef* sp = sstart; +static size_t +ZSTD_seqDecompressedSize(SeqStore_t const* seqStore, + const SeqDef* sequences, size_t nbSeqs, + size_t litSize, int lastSubBlock) +{ size_t matchLengthSum = 0; size_t litLengthSum = 0; - (void)(litLengthSum); /* suppress unused variable warning on some environments */ - while (send-sp > 0) { - ZSTD_sequenceLength const seqLen = ZSTD_getSequenceLength(seqStore, sp); + size_t n; + for (n=0; ncParams.windowLog > STREAM_ACCUMULATOR_MIN; BYTE* const ostart = (BYTE*)dst; @@ -176,14 +181,14 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables /* Sequences Header */ RETURN_ERROR_IF((oend-op) < 3 /*max nbSeq Size*/ + 1 /*seqHead*/, dstSize_tooSmall, ""); - if (nbSeq < 0x7F) + if (nbSeq < 128) *op++ = (BYTE)nbSeq; else if (nbSeq < LONGNBSEQ) op[0] = (BYTE)((nbSeq>>8) + 0x80), op[1] = (BYTE)nbSeq, op+=2; else op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3; if (nbSeq==0) { - return op - ostart; + return (size_t)(op - ostart); } /* seqHead : flags for FSE encoding type */ @@ -205,7 +210,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables } { size_t const bitstreamSize = ZSTD_encodeSequences( - op, oend - op, + op, (size_t)(oend - op), fseTables->matchlengthCTable, mlCode, fseTables->offcodeCTable, ofCode, fseTables->litlengthCTable, llCode, @@ -249,7 +254,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables #endif *entropyWritten = 1; - return op - ostart; + return (size_t)(op - ostart); } /* ZSTD_compressSubBlock() : @@ -258,7 +263,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables * Or 0 if it failed to compress. */ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, const ZSTD_entropyCTablesMetadata_t* entropyMetadata, - const seqDef* sequences, size_t nbSeq, + const SeqDef* sequences, size_t nbSeq, const BYTE* literals, size_t litSize, const BYTE* llCode, const BYTE* mlCode, const BYTE* ofCode, const ZSTD_CCtx_params* cctxParams, @@ -275,7 +280,8 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, litSize, nbSeq, writeLitEntropy, writeSeqEntropy, lastBlock); { size_t cLitSize = ZSTD_compressSubBlock_literal((const HUF_CElt*)entropy->huf.CTable, &entropyMetadata->hufMetadata, literals, litSize, - op, oend-op, bmi2, writeLitEntropy, litEntropyWritten); + op, (size_t)(oend-op), + bmi2, writeLitEntropy, litEntropyWritten); FORWARD_IF_ERROR(cLitSize, "ZSTD_compressSubBlock_literal failed"); if (cLitSize == 0) return 0; op += cLitSize; @@ -285,18 +291,18 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, sequences, nbSeq, llCode, mlCode, ofCode, cctxParams, - op, oend-op, + op, (size_t)(oend-op), bmi2, writeSeqEntropy, seqEntropyWritten); FORWARD_IF_ERROR(cSeqSize, "ZSTD_compressSubBlock_sequences failed"); if (cSeqSize == 0) return 0; op += cSeqSize; } /* Write block header */ - { size_t cSize = (op-ostart)-ZSTD_blockHeaderSize; + { size_t cSize = (size_t)(op-ostart) - ZSTD_blockHeaderSize; U32 const cBlockHeader24 = lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); MEM_writeLE24(ostart, cBlockHeader24); } - return op-ostart; + return (size_t)(op-ostart); } static size_t ZSTD_estimateSubBlockSize_literal(const BYTE* literals, size_t litSize, @@ -322,7 +328,7 @@ static size_t ZSTD_estimateSubBlockSize_literal(const BYTE* literals, size_t lit return 0; } -static size_t ZSTD_estimateSubBlockSize_symbolType(symbolEncodingType_e type, +static size_t ZSTD_estimateSubBlockSize_symbolType(SymbolEncodingType_e type, const BYTE* codeTable, unsigned maxCode, size_t nbSeq, const FSE_CTable* fseCTable, const U8* additionalBits, @@ -385,7 +391,11 @@ static size_t ZSTD_estimateSubBlockSize_sequences(const BYTE* ofCodeTable, return cSeqSizeEstimate + sequencesSectionHeaderSize; } -static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, +typedef struct { + size_t estLitSize; + size_t estBlockSize; +} EstimatedBlockSize; +static EstimatedBlockSize ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, const BYTE* ofCodeTable, const BYTE* llCodeTable, const BYTE* mlCodeTable, @@ -393,15 +403,17 @@ static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, const ZSTD_entropyCTables_t* entropy, const ZSTD_entropyCTablesMetadata_t* entropyMetadata, void* workspace, size_t wkspSize, - int writeLitEntropy, int writeSeqEntropy) { - size_t cSizeEstimate = 0; - cSizeEstimate += ZSTD_estimateSubBlockSize_literal(literals, litSize, - &entropy->huf, &entropyMetadata->hufMetadata, - workspace, wkspSize, writeLitEntropy); - cSizeEstimate += ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, + int writeLitEntropy, int writeSeqEntropy) +{ + EstimatedBlockSize ebs; + ebs.estLitSize = ZSTD_estimateSubBlockSize_literal(literals, litSize, + &entropy->huf, &entropyMetadata->hufMetadata, + workspace, wkspSize, writeLitEntropy); + ebs.estBlockSize = ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, workspace, wkspSize, writeSeqEntropy); - return cSizeEstimate + ZSTD_blockHeaderSize; + ebs.estBlockSize += ebs.estLitSize + ZSTD_blockHeaderSize; + return ebs; } static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMetadata) @@ -415,14 +427,57 @@ static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMe return 0; } +static size_t countLiterals(SeqStore_t const* seqStore, const SeqDef* sp, size_t seqCount) +{ + size_t n, total = 0; + assert(sp != NULL); + for (n=0; n %zu bytes", seqCount, (const void*)sp, total); + return total; +} + +#define BYTESCALE 256 + +static size_t sizeBlockSequences(const SeqDef* sp, size_t nbSeqs, + size_t targetBudget, size_t avgLitCost, size_t avgSeqCost, + int firstSubBlock) +{ + size_t n, budget = 0, inSize=0; + /* entropy headers */ + size_t const headerSize = (size_t)firstSubBlock * 120 * BYTESCALE; /* generous estimate */ + assert(firstSubBlock==0 || firstSubBlock==1); + budget += headerSize; + + /* first sequence => at least one sequence*/ + budget += sp[0].litLength * avgLitCost + avgSeqCost; + if (budget > targetBudget) return 1; + inSize = sp[0].litLength + (sp[0].mlBase+MINMATCH); + + /* loop over sequences */ + for (n=1; n targetBudget) + /* though continue to expand until the sub-block is deemed compressible */ + && (budget < inSize * BYTESCALE) ) + break; + } + + return n; +} + /* ZSTD_compressSubBlock_multi() : * Breaks super-block into multiple sub-blocks and compresses them. - * Entropy will be written to the first block. - * The following blocks will use repeat mode to compress. - * All sub-blocks are compressed blocks (no raw or rle blocks). - * @return : compressed size of the super block (which is multiple ZSTD blocks) - * Or 0 if it failed to compress. */ -static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, + * Entropy will be written into the first block. + * The following blocks use repeat_mode to compress. + * Sub-blocks are all compressed, except the last one when beneficial. + * @return : compressed size of the super block (which features multiple ZSTD blocks) + * or 0 if it failed to compress. */ +static size_t ZSTD_compressSubBlock_multi(const SeqStore_t* seqStorePtr, const ZSTD_compressedBlockState_t* prevCBlock, ZSTD_compressedBlockState_t* nextCBlock, const ZSTD_entropyCTablesMetadata_t* entropyMetadata, @@ -432,12 +487,14 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, const int bmi2, U32 lastBlock, void* workspace, size_t wkspSize) { - const seqDef* const sstart = seqStorePtr->sequencesStart; - const seqDef* const send = seqStorePtr->sequences; - const seqDef* sp = sstart; + const SeqDef* const sstart = seqStorePtr->sequencesStart; + const SeqDef* const send = seqStorePtr->sequences; + const SeqDef* sp = sstart; /* tracks progresses within seqStorePtr->sequences */ + size_t const nbSeqs = (size_t)(send - sstart); const BYTE* const lstart = seqStorePtr->litStart; const BYTE* const lend = seqStorePtr->lit; const BYTE* lp = lstart; + size_t const nbLiterals = (size_t)(lend - lstart); BYTE const* ip = (BYTE const*)src; BYTE const* const iend = ip + srcSize; BYTE* const ostart = (BYTE*)dst; @@ -446,112 +503,171 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, const BYTE* llCodePtr = seqStorePtr->llCode; const BYTE* mlCodePtr = seqStorePtr->mlCode; const BYTE* ofCodePtr = seqStorePtr->ofCode; - size_t targetCBlockSize = cctxParams->targetCBlockSize; - size_t litSize, seqCount; - int writeLitEntropy = entropyMetadata->hufMetadata.hType == set_compressed; + size_t const minTarget = ZSTD_TARGETCBLOCKSIZE_MIN; /* enforce minimum size, to reduce undesirable side effects */ + size_t const targetCBlockSize = MAX(minTarget, cctxParams->targetCBlockSize); + int writeLitEntropy = (entropyMetadata->hufMetadata.hType == set_compressed); int writeSeqEntropy = 1; - int lastSequence = 0; - - DEBUGLOG(5, "ZSTD_compressSubBlock_multi (litSize=%u, nbSeq=%u)", - (unsigned)(lend-lp), (unsigned)(send-sstart)); - - litSize = 0; - seqCount = 0; - do { - size_t cBlockSizeEstimate = 0; - if (sstart == send) { - lastSequence = 1; - } else { - const seqDef* const sequence = sp + seqCount; - lastSequence = sequence == send - 1; - litSize += ZSTD_getSequenceLength(seqStorePtr, sequence).litLength; - seqCount++; - } - if (lastSequence) { - assert(lp <= lend); - assert(litSize <= (size_t)(lend - lp)); - litSize = (size_t)(lend - lp); + + DEBUGLOG(5, "ZSTD_compressSubBlock_multi (srcSize=%u, litSize=%u, nbSeq=%u)", + (unsigned)srcSize, (unsigned)(lend-lstart), (unsigned)(send-sstart)); + + /* let's start by a general estimation for the full block */ + if (nbSeqs > 0) { + EstimatedBlockSize const ebs = + ZSTD_estimateSubBlockSize(lp, nbLiterals, + ofCodePtr, llCodePtr, mlCodePtr, nbSeqs, + &nextCBlock->entropy, entropyMetadata, + workspace, wkspSize, + writeLitEntropy, writeSeqEntropy); + /* quick estimation */ + size_t const avgLitCost = nbLiterals ? (ebs.estLitSize * BYTESCALE) / nbLiterals : BYTESCALE; + size_t const avgSeqCost = ((ebs.estBlockSize - ebs.estLitSize) * BYTESCALE) / nbSeqs; + const size_t nbSubBlocks = MAX((ebs.estBlockSize + (targetCBlockSize/2)) / targetCBlockSize, 1); + size_t n, avgBlockBudget, blockBudgetSupp=0; + avgBlockBudget = (ebs.estBlockSize * BYTESCALE) / nbSubBlocks; + DEBUGLOG(5, "estimated fullblock size=%u bytes ; avgLitCost=%.2f ; avgSeqCost=%.2f ; targetCBlockSize=%u, nbSubBlocks=%u ; avgBlockBudget=%.0f bytes", + (unsigned)ebs.estBlockSize, (double)avgLitCost/BYTESCALE, (double)avgSeqCost/BYTESCALE, + (unsigned)targetCBlockSize, (unsigned)nbSubBlocks, (double)avgBlockBudget/BYTESCALE); + /* simplification: if estimates states that the full superblock doesn't compress, just bail out immediately + * this will result in the production of a single uncompressed block covering @srcSize.*/ + if (ebs.estBlockSize > srcSize) return 0; + + /* compress and write sub-blocks */ + assert(nbSubBlocks>0); + for (n=0; n < nbSubBlocks-1; n++) { + /* determine nb of sequences for current sub-block + nbLiterals from next sequence */ + size_t const seqCount = sizeBlockSequences(sp, (size_t)(send-sp), + avgBlockBudget + blockBudgetSupp, avgLitCost, avgSeqCost, n==0); + /* if reached last sequence : break to last sub-block (simplification) */ + assert(seqCount <= (size_t)(send-sp)); + if (sp + seqCount == send) break; + assert(seqCount > 0); + /* compress sub-block */ + { int litEntropyWritten = 0; + int seqEntropyWritten = 0; + size_t litSize = countLiterals(seqStorePtr, sp, seqCount); + const size_t decompressedSize = + ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 0); + size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, + sp, seqCount, + lp, litSize, + llCodePtr, mlCodePtr, ofCodePtr, + cctxParams, + op, (size_t)(oend-op), + bmi2, writeLitEntropy, writeSeqEntropy, + &litEntropyWritten, &seqEntropyWritten, + 0); + FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); + + /* check compressibility, update state components */ + if (cSize > 0 && cSize < decompressedSize) { + DEBUGLOG(5, "Committed sub-block compressing %u bytes => %u bytes", + (unsigned)decompressedSize, (unsigned)cSize); + assert(ip + decompressedSize <= iend); + ip += decompressedSize; + lp += litSize; + op += cSize; + llCodePtr += seqCount; + mlCodePtr += seqCount; + ofCodePtr += seqCount; + /* Entropy only needs to be written once */ + if (litEntropyWritten) { + writeLitEntropy = 0; + } + if (seqEntropyWritten) { + writeSeqEntropy = 0; + } + sp += seqCount; + blockBudgetSupp = 0; + } } + /* otherwise : do not compress yet, coalesce current sub-block with following one */ } - /* I think there is an optimization opportunity here. - * Calling ZSTD_estimateSubBlockSize for every sequence can be wasteful - * since it recalculates estimate from scratch. - * For example, it would recount literal distribution and symbol codes every time. - */ - cBlockSizeEstimate = ZSTD_estimateSubBlockSize(lp, litSize, ofCodePtr, llCodePtr, mlCodePtr, seqCount, - &nextCBlock->entropy, entropyMetadata, - workspace, wkspSize, writeLitEntropy, writeSeqEntropy); - if (cBlockSizeEstimate > targetCBlockSize || lastSequence) { - int litEntropyWritten = 0; - int seqEntropyWritten = 0; - const size_t decompressedSize = ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, lastSequence); - const size_t cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, - sp, seqCount, - lp, litSize, - llCodePtr, mlCodePtr, ofCodePtr, - cctxParams, - op, oend-op, - bmi2, writeLitEntropy, writeSeqEntropy, - &litEntropyWritten, &seqEntropyWritten, - lastBlock && lastSequence); - FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); - if (cSize > 0 && cSize < decompressedSize) { - DEBUGLOG(5, "Committed the sub-block"); - assert(ip + decompressedSize <= iend); - ip += decompressedSize; - sp += seqCount; - lp += litSize; - op += cSize; - llCodePtr += seqCount; - mlCodePtr += seqCount; - ofCodePtr += seqCount; - litSize = 0; - seqCount = 0; - /* Entropy only needs to be written once */ - if (litEntropyWritten) { - writeLitEntropy = 0; - } - if (seqEntropyWritten) { - writeSeqEntropy = 0; - } + } /* if (nbSeqs > 0) */ + + /* write last block */ + DEBUGLOG(5, "Generate last sub-block: %u sequences remaining", (unsigned)(send - sp)); + { int litEntropyWritten = 0; + int seqEntropyWritten = 0; + size_t litSize = (size_t)(lend - lp); + size_t seqCount = (size_t)(send - sp); + const size_t decompressedSize = + ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 1); + size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, + sp, seqCount, + lp, litSize, + llCodePtr, mlCodePtr, ofCodePtr, + cctxParams, + op, (size_t)(oend-op), + bmi2, writeLitEntropy, writeSeqEntropy, + &litEntropyWritten, &seqEntropyWritten, + lastBlock); + FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); + + /* update pointers, the nb of literals borrowed from next sequence must be preserved */ + if (cSize > 0 && cSize < decompressedSize) { + DEBUGLOG(5, "Last sub-block compressed %u bytes => %u bytes", + (unsigned)decompressedSize, (unsigned)cSize); + assert(ip + decompressedSize <= iend); + ip += decompressedSize; + lp += litSize; + op += cSize; + llCodePtr += seqCount; + mlCodePtr += seqCount; + ofCodePtr += seqCount; + /* Entropy only needs to be written once */ + if (litEntropyWritten) { + writeLitEntropy = 0; } + if (seqEntropyWritten) { + writeSeqEntropy = 0; + } + sp += seqCount; } - } while (!lastSequence); + } + + if (writeLitEntropy) { - DEBUGLOG(5, "ZSTD_compressSubBlock_multi has literal entropy tables unwritten"); + DEBUGLOG(5, "Literal entropy tables were never written"); ZSTD_memcpy(&nextCBlock->entropy.huf, &prevCBlock->entropy.huf, sizeof(prevCBlock->entropy.huf)); } if (writeSeqEntropy && ZSTD_needSequenceEntropyTables(&entropyMetadata->fseMetadata)) { /* If we haven't written our entropy tables, then we've violated our contract and * must emit an uncompressed block. */ - DEBUGLOG(5, "ZSTD_compressSubBlock_multi has sequence entropy tables unwritten"); + DEBUGLOG(5, "Sequence entropy tables were never written => cancel, emit an uncompressed block"); return 0; } + if (ip < iend) { - size_t const cSize = ZSTD_noCompressBlock(op, oend - op, ip, iend - ip, lastBlock); - DEBUGLOG(5, "ZSTD_compressSubBlock_multi last sub-block uncompressed, %zu bytes", (size_t)(iend - ip)); + /* some data left : last part of the block sent uncompressed */ + size_t const rSize = (size_t)((iend - ip)); + size_t const cSize = ZSTD_noCompressBlock(op, (size_t)(oend - op), ip, rSize, lastBlock); + DEBUGLOG(5, "Generate last uncompressed sub-block of %u bytes", (unsigned)(rSize)); FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); assert(cSize != 0); op += cSize; /* We have to regenerate the repcodes because we've skipped some sequences */ if (sp < send) { - seqDef const* seq; - repcodes_t rep; + const SeqDef* seq; + Repcodes_t rep; ZSTD_memcpy(&rep, prevCBlock->rep, sizeof(rep)); for (seq = sstart; seq < sp; ++seq) { - ZSTD_updateRep(rep.rep, seq->offBase - 1, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); + ZSTD_updateRep(rep.rep, seq->offBase, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); } ZSTD_memcpy(nextCBlock->rep, &rep, sizeof(rep)); } } - DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed"); - return op-ostart; + + DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed all subBlocks: total compressed size = %u", + (unsigned)(op-ostart)); + return (size_t)(op-ostart); } size_t ZSTD_compressSuperBlock(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, - void const* src, size_t srcSize, - unsigned lastBlock) { + const void* src, size_t srcSize, + unsigned lastBlock) +{ ZSTD_entropyCTablesMetadata_t entropyMetadata; FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(&zc->seqStore, @@ -559,7 +675,7 @@ size_t ZSTD_compressSuperBlock(ZSTD_CCtx* zc, &zc->blockState.nextCBlock->entropy, &zc->appliedParams, &entropyMetadata, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */), ""); + zc->tmpWorkspace, zc->tmpWkspSize /* statically allocated in resetCCtx */), ""); return ZSTD_compressSubBlock_multi(&zc->seqStore, zc->blockState.prevCBlock, @@ -569,5 +685,5 @@ size_t ZSTD_compressSuperBlock(ZSTD_CCtx* zc, dst, dstCapacity, src, srcSize, zc->bmi2, lastBlock, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */); + zc->tmpWorkspace, zc->tmpWkspSize /* statically allocated in resetCCtx */); } diff --git a/lib/zstd/compress/zstd_compress_superblock.h b/lib/zstd/compress/zstd_compress_superblock.h index 224ece79546e..826bbc9e029b 100644 --- a/lib/zstd/compress/zstd_compress_superblock.h +++ b/lib/zstd/compress/zstd_compress_superblock.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_cwksp.h b/lib/zstd/compress/zstd_cwksp.h index 349fc923c355..dce42f653bae 100644 --- a/lib/zstd/compress/zstd_cwksp.h +++ b/lib/zstd/compress/zstd_cwksp.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,8 +15,10 @@ /*-************************************* * Dependencies ***************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ #include "../common/zstd_internal.h" - +#include "../common/portability_macros.h" +#include "../common/compiler.h" /* ZS2_isPower2 */ /*-************************************* * Constants @@ -41,8 +44,9 @@ ***************************************/ typedef enum { ZSTD_cwksp_alloc_objects, - ZSTD_cwksp_alloc_buffers, - ZSTD_cwksp_alloc_aligned + ZSTD_cwksp_alloc_aligned_init_once, + ZSTD_cwksp_alloc_aligned, + ZSTD_cwksp_alloc_buffers } ZSTD_cwksp_alloc_phase_e; /* @@ -95,8 +99,8 @@ typedef enum { * * Workspace Layout: * - * [ ... workspace ... ] - * [objects][tables ... ->] free space [<- ... aligned][<- ... buffers] + * [ ... workspace ... ] + * [objects][tables ->] free space [<- buffers][<- aligned][<- init once] * * The various objects that live in the workspace are divided into the * following categories, and are allocated separately: @@ -120,9 +124,18 @@ typedef enum { * uint32_t arrays, all of whose values are between 0 and (nextSrc - base). * Their sizes depend on the cparams. These tables are 64-byte aligned. * - * - Aligned: these buffers are used for various purposes that require 4 byte - * alignment, but don't require any initialization before they're used. These - * buffers are each aligned to 64 bytes. + * - Init once: these buffers require to be initialized at least once before + * use. They should be used when we want to skip memory initialization + * while not triggering memory checkers (like Valgrind) when reading from + * from this memory without writing to it first. + * These buffers should be used carefully as they might contain data + * from previous compressions. + * Buffers are aligned to 64 bytes. + * + * - Aligned: these buffers don't require any initialization before they're + * used. The user of the buffer should make sure they write into a buffer + * location before reading from it. + * Buffers are aligned to 64 bytes. * * - Buffers: these buffers are used for various purposes that don't require * any alignment or initialization before they're used. This means they can @@ -134,8 +147,9 @@ typedef enum { * correctly packed into the workspace buffer. That order is: * * 1. Objects - * 2. Buffers - * 3. Aligned/Tables + * 2. Init once / Tables + * 3. Aligned / Tables + * 4. Buffers / Tables * * Attempts to reserve objects of different types out of order will fail. */ @@ -147,6 +161,7 @@ typedef struct { void* tableEnd; void* tableValidEnd; void* allocStart; + void* initOnceStart; BYTE allocFailed; int workspaceOversizedDuration; @@ -159,6 +174,7 @@ typedef struct { ***************************************/ MEM_STATIC size_t ZSTD_cwksp_available_space(ZSTD_cwksp* ws); +MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws); MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { (void)ws; @@ -168,14 +184,16 @@ MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { assert(ws->tableEnd <= ws->allocStart); assert(ws->tableValidEnd <= ws->allocStart); assert(ws->allocStart <= ws->workspaceEnd); + assert(ws->initOnceStart <= ZSTD_cwksp_initialAllocStart(ws)); + assert(ws->workspace <= ws->initOnceStart); } /* * Align must be a power of 2. */ -MEM_STATIC size_t ZSTD_cwksp_align(size_t size, size_t const align) { +MEM_STATIC size_t ZSTD_cwksp_align(size_t size, size_t align) { size_t const mask = align - 1; - assert((align & mask) == 0); + assert(ZSTD_isPower2(align)); return (size + mask) & ~mask; } @@ -189,7 +207,7 @@ MEM_STATIC size_t ZSTD_cwksp_align(size_t size, size_t const align) { * to figure out how much space you need for the matchState tables. Everything * else is though. * - * Do not use for sizing aligned buffers. Instead, use ZSTD_cwksp_aligned_alloc_size(). + * Do not use for sizing aligned buffers. Instead, use ZSTD_cwksp_aligned64_alloc_size(). */ MEM_STATIC size_t ZSTD_cwksp_alloc_size(size_t size) { if (size == 0) @@ -197,12 +215,16 @@ MEM_STATIC size_t ZSTD_cwksp_alloc_size(size_t size) { return size; } +MEM_STATIC size_t ZSTD_cwksp_aligned_alloc_size(size_t size, size_t alignment) { + return ZSTD_cwksp_alloc_size(ZSTD_cwksp_align(size, alignment)); +} + /* * Returns an adjusted alloc size that is the nearest larger multiple of 64 bytes. * Used to determine the number of bytes required for a given "aligned". */ -MEM_STATIC size_t ZSTD_cwksp_aligned_alloc_size(size_t size) { - return ZSTD_cwksp_alloc_size(ZSTD_cwksp_align(size, ZSTD_CWKSP_ALIGNMENT_BYTES)); +MEM_STATIC size_t ZSTD_cwksp_aligned64_alloc_size(size_t size) { + return ZSTD_cwksp_aligned_alloc_size(size, ZSTD_CWKSP_ALIGNMENT_BYTES); } /* @@ -210,14 +232,10 @@ MEM_STATIC size_t ZSTD_cwksp_aligned_alloc_size(size_t size) { * for internal purposes (currently only alignment). */ MEM_STATIC size_t ZSTD_cwksp_slack_space_required(void) { - /* For alignment, the wksp will always allocate an additional n_1=[1, 64] bytes - * to align the beginning of tables section, as well as another n_2=[0, 63] bytes - * to align the beginning of the aligned section. - * - * n_1 + n_2 == 64 bytes if the cwksp is freshly allocated, due to tables and - * aligneds being sized in multiples of 64 bytes. + /* For alignment, the wksp will always allocate an additional 2*ZSTD_CWKSP_ALIGNMENT_BYTES + * bytes to align the beginning of tables section and end of buffers; */ - size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES; + size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES * 2; return slackSpace; } @@ -229,11 +247,23 @@ MEM_STATIC size_t ZSTD_cwksp_slack_space_required(void) { MEM_STATIC size_t ZSTD_cwksp_bytes_to_align_ptr(void* ptr, const size_t alignBytes) { size_t const alignBytesMask = alignBytes - 1; size_t const bytes = (alignBytes - ((size_t)ptr & (alignBytesMask))) & alignBytesMask; - assert((alignBytes & alignBytesMask) == 0); - assert(bytes != ZSTD_CWKSP_ALIGNMENT_BYTES); + assert(ZSTD_isPower2(alignBytes)); + assert(bytes < alignBytes); return bytes; } +/* + * Returns the initial value for allocStart which is used to determine the position from + * which we can allocate from the end of the workspace. + */ +MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws) +{ + char* endPtr = (char*)ws->workspaceEnd; + assert(ZSTD_isPower2(ZSTD_CWKSP_ALIGNMENT_BYTES)); + endPtr = endPtr - ((size_t)endPtr % ZSTD_CWKSP_ALIGNMENT_BYTES); + return (void*)endPtr; +} + /* * Internal function. Do not use directly. * Reserves the given number of bytes within the aligned/buffer segment of the wksp, @@ -246,7 +276,7 @@ ZSTD_cwksp_reserve_internal_buffer_space(ZSTD_cwksp* ws, size_t const bytes) { void* const alloc = (BYTE*)ws->allocStart - bytes; void* const bottom = ws->tableEnd; - DEBUGLOG(5, "cwksp: reserving %p %zd bytes, %zd bytes remaining", + DEBUGLOG(5, "cwksp: reserving [0x%p]:%zd bytes; %zd bytes remaining", alloc, bytes, ZSTD_cwksp_available_space(ws) - bytes); ZSTD_cwksp_assert_internal_consistency(ws); assert(alloc >= bottom); @@ -274,27 +304,16 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase { assert(phase >= ws->phase); if (phase > ws->phase) { - /* Going from allocating objects to allocating buffers */ - if (ws->phase < ZSTD_cwksp_alloc_buffers && - phase >= ZSTD_cwksp_alloc_buffers) { + /* Going from allocating objects to allocating initOnce / tables */ + if (ws->phase < ZSTD_cwksp_alloc_aligned_init_once && + phase >= ZSTD_cwksp_alloc_aligned_init_once) { ws->tableValidEnd = ws->objectEnd; - } + ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); - /* Going from allocating buffers to allocating aligneds/tables */ - if (ws->phase < ZSTD_cwksp_alloc_aligned && - phase >= ZSTD_cwksp_alloc_aligned) { - { /* Align the start of the "aligned" to 64 bytes. Use [1, 64] bytes. */ - size_t const bytesToAlign = - ZSTD_CWKSP_ALIGNMENT_BYTES - ZSTD_cwksp_bytes_to_align_ptr(ws->allocStart, ZSTD_CWKSP_ALIGNMENT_BYTES); - DEBUGLOG(5, "reserving aligned alignment addtl space: %zu", bytesToAlign); - ZSTD_STATIC_ASSERT((ZSTD_CWKSP_ALIGNMENT_BYTES & (ZSTD_CWKSP_ALIGNMENT_BYTES - 1)) == 0); /* power of 2 */ - RETURN_ERROR_IF(!ZSTD_cwksp_reserve_internal_buffer_space(ws, bytesToAlign), - memory_allocation, "aligned phase - alignment initial allocation failed!"); - } { /* Align the start of the tables to 64 bytes. Use [0, 63] bytes */ - void* const alloc = ws->objectEnd; + void *const alloc = ws->objectEnd; size_t const bytesToAlign = ZSTD_cwksp_bytes_to_align_ptr(alloc, ZSTD_CWKSP_ALIGNMENT_BYTES); - void* const objectEnd = (BYTE*)alloc + bytesToAlign; + void *const objectEnd = (BYTE *) alloc + bytesToAlign; DEBUGLOG(5, "reserving table alignment addtl space: %zu", bytesToAlign); RETURN_ERROR_IF(objectEnd > ws->workspaceEnd, memory_allocation, "table phase - alignment initial allocation failed!"); @@ -302,7 +321,9 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase ws->tableEnd = objectEnd; /* table area starts being empty */ if (ws->tableValidEnd < ws->tableEnd) { ws->tableValidEnd = ws->tableEnd; - } } } + } + } + } ws->phase = phase; ZSTD_cwksp_assert_internal_consistency(ws); } @@ -314,7 +335,7 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase */ MEM_STATIC int ZSTD_cwksp_owns_buffer(const ZSTD_cwksp* ws, const void* ptr) { - return (ptr != NULL) && (ws->workspace <= ptr) && (ptr <= ws->workspaceEnd); + return (ptr != NULL) && (ws->workspace <= ptr) && (ptr < ws->workspaceEnd); } /* @@ -345,29 +366,61 @@ MEM_STATIC BYTE* ZSTD_cwksp_reserve_buffer(ZSTD_cwksp* ws, size_t bytes) /* * Reserves and returns memory sized on and aligned on ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). + * This memory has been initialized at least once in the past. + * This doesn't mean it has been initialized this time, and it might contain data from previous + * operations. + * The main usage is for algorithms that might need read access into uninitialized memory. + * The algorithm must maintain safety under these conditions and must make sure it doesn't + * leak any of the past data (directly or in side channels). */ -MEM_STATIC void* ZSTD_cwksp_reserve_aligned(ZSTD_cwksp* ws, size_t bytes) +MEM_STATIC void* ZSTD_cwksp_reserve_aligned_init_once(ZSTD_cwksp* ws, size_t bytes) { - void* ptr = ZSTD_cwksp_reserve_internal(ws, ZSTD_cwksp_align(bytes, ZSTD_CWKSP_ALIGNMENT_BYTES), - ZSTD_cwksp_alloc_aligned); - assert(((size_t)ptr & (ZSTD_CWKSP_ALIGNMENT_BYTES-1))== 0); + size_t const alignedBytes = ZSTD_cwksp_align(bytes, ZSTD_CWKSP_ALIGNMENT_BYTES); + void* ptr = ZSTD_cwksp_reserve_internal(ws, alignedBytes, ZSTD_cwksp_alloc_aligned_init_once); + assert(((size_t)ptr & (ZSTD_CWKSP_ALIGNMENT_BYTES-1)) == 0); + if(ptr && ptr < ws->initOnceStart) { + /* We assume the memory following the current allocation is either: + * 1. Not usable as initOnce memory (end of workspace) + * 2. Another initOnce buffer that has been allocated before (and so was previously memset) + * 3. An ASAN redzone, in which case we don't want to write on it + * For these reasons it should be fine to not explicitly zero every byte up to ws->initOnceStart. + * Note that we assume here that MSAN and ASAN cannot run in the same time. */ + ZSTD_memset(ptr, 0, MIN((size_t)((U8*)ws->initOnceStart - (U8*)ptr), alignedBytes)); + ws->initOnceStart = ptr; + } + return ptr; +} + +/* + * Reserves and returns memory sized on and aligned on ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). + */ +MEM_STATIC void* ZSTD_cwksp_reserve_aligned64(ZSTD_cwksp* ws, size_t bytes) +{ + void* const ptr = ZSTD_cwksp_reserve_internal(ws, + ZSTD_cwksp_align(bytes, ZSTD_CWKSP_ALIGNMENT_BYTES), + ZSTD_cwksp_alloc_aligned); + assert(((size_t)ptr & (ZSTD_CWKSP_ALIGNMENT_BYTES-1)) == 0); return ptr; } /* * Aligned on 64 bytes. These buffers have the special property that - * their values remain constrained, allowing us to re-use them without + * their values remain constrained, allowing us to reuse them without * memset()-ing them. */ MEM_STATIC void* ZSTD_cwksp_reserve_table(ZSTD_cwksp* ws, size_t bytes) { - const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned; + const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned_init_once; void* alloc; void* end; void* top; - if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { - return NULL; + /* We can only start allocating tables after we are done reserving space for objects at the + * start of the workspace */ + if(ws->phase < phase) { + if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { + return NULL; + } } alloc = ws->tableEnd; end = (BYTE *)alloc + bytes; @@ -387,7 +440,7 @@ MEM_STATIC void* ZSTD_cwksp_reserve_table(ZSTD_cwksp* ws, size_t bytes) assert((bytes & (ZSTD_CWKSP_ALIGNMENT_BYTES-1)) == 0); - assert(((size_t)alloc & (ZSTD_CWKSP_ALIGNMENT_BYTES-1))== 0); + assert(((size_t)alloc & (ZSTD_CWKSP_ALIGNMENT_BYTES-1)) == 0); return alloc; } @@ -421,6 +474,20 @@ MEM_STATIC void* ZSTD_cwksp_reserve_object(ZSTD_cwksp* ws, size_t bytes) return alloc; } +/* + * with alignment control + * Note : should happen only once, at workspace first initialization + */ +MEM_STATIC void* ZSTD_cwksp_reserve_object_aligned(ZSTD_cwksp* ws, size_t byteSize, size_t alignment) +{ + size_t const mask = alignment - 1; + size_t const surplus = (alignment > sizeof(void*)) ? alignment - sizeof(void*) : 0; + void* const start = ZSTD_cwksp_reserve_object(ws, byteSize + surplus); + if (start == NULL) return NULL; + if (surplus == 0) return start; + assert(ZSTD_isPower2(alignment)); + return (void*)(((size_t)start + surplus) & ~mask); +} MEM_STATIC void ZSTD_cwksp_mark_tables_dirty(ZSTD_cwksp* ws) { @@ -451,7 +518,7 @@ MEM_STATIC void ZSTD_cwksp_clean_tables(ZSTD_cwksp* ws) { assert(ws->tableValidEnd >= ws->objectEnd); assert(ws->tableValidEnd <= ws->allocStart); if (ws->tableValidEnd < ws->tableEnd) { - ZSTD_memset(ws->tableValidEnd, 0, (BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd); + ZSTD_memset(ws->tableValidEnd, 0, (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd)); } ZSTD_cwksp_mark_tables_clean(ws); } @@ -460,7 +527,8 @@ MEM_STATIC void ZSTD_cwksp_clean_tables(ZSTD_cwksp* ws) { * Invalidates table allocations. * All other allocations remain valid. */ -MEM_STATIC void ZSTD_cwksp_clear_tables(ZSTD_cwksp* ws) { +MEM_STATIC void ZSTD_cwksp_clear_tables(ZSTD_cwksp* ws) +{ DEBUGLOG(4, "cwksp: clearing tables!"); @@ -478,14 +546,23 @@ MEM_STATIC void ZSTD_cwksp_clear(ZSTD_cwksp* ws) { ws->tableEnd = ws->objectEnd; - ws->allocStart = ws->workspaceEnd; + ws->allocStart = ZSTD_cwksp_initialAllocStart(ws); ws->allocFailed = 0; - if (ws->phase > ZSTD_cwksp_alloc_buffers) { - ws->phase = ZSTD_cwksp_alloc_buffers; + if (ws->phase > ZSTD_cwksp_alloc_aligned_init_once) { + ws->phase = ZSTD_cwksp_alloc_aligned_init_once; } ZSTD_cwksp_assert_internal_consistency(ws); } +MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { + return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); +} + +MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { + return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) + + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); +} + /* * The provided workspace takes ownership of the buffer [start, start+size). * Any existing values in the workspace are ignored (the previously managed @@ -498,6 +575,7 @@ MEM_STATIC void ZSTD_cwksp_init(ZSTD_cwksp* ws, void* start, size_t size, ZSTD_c ws->workspaceEnd = (BYTE*)start + size; ws->objectEnd = ws->workspace; ws->tableValidEnd = ws->objectEnd; + ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); ws->phase = ZSTD_cwksp_alloc_objects; ws->isStatic = isStatic; ZSTD_cwksp_clear(ws); @@ -529,15 +607,6 @@ MEM_STATIC void ZSTD_cwksp_move(ZSTD_cwksp* dst, ZSTD_cwksp* src) { ZSTD_memset(src, 0, sizeof(ZSTD_cwksp)); } -MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { - return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); -} - -MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { - return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) - + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); -} - MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { return ws->allocFailed; } @@ -550,17 +619,11 @@ MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { * Returns if the estimated space needed for a wksp is within an acceptable limit of the * actual amount of space used. */ -MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp* const ws, - size_t const estimatedSpace, int resizedWorkspace) { - if (resizedWorkspace) { - /* Resized/newly allocated wksp should have exact bounds */ - return ZSTD_cwksp_used(ws) == estimatedSpace; - } else { - /* Due to alignment, when reusing a workspace, we can actually consume 63 fewer or more bytes - * than estimatedSpace. See the comments in zstd_cwksp.h for details. - */ - return (ZSTD_cwksp_used(ws) >= estimatedSpace - 63) && (ZSTD_cwksp_used(ws) <= estimatedSpace + 63); - } +MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp *const ws, size_t const estimatedSpace) { + /* We have an alignment space between objects and tables between tables and buffers, so we can have up to twice + * the alignment bytes difference between estimation and actual usage */ + return (estimatedSpace - ZSTD_cwksp_slack_space_required()) <= ZSTD_cwksp_used(ws) && + ZSTD_cwksp_used(ws) <= estimatedSpace; } @@ -591,5 +654,4 @@ MEM_STATIC void ZSTD_cwksp_bump_oversized_duration( } } - #endif /* ZSTD_CWKSP_H */ diff --git a/lib/zstd/compress/zstd_double_fast.c b/lib/zstd/compress/zstd_double_fast.c index 76933dea2624..995e83f3a183 100644 --- a/lib/zstd/compress/zstd_double_fast.c +++ b/lib/zstd/compress/zstd_double_fast.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,8 +12,49 @@ #include "zstd_compress_internal.h" #include "zstd_double_fast.h" +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR -void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillDoubleHashTableForCDict(ZSTD_MatchState_t* ms, + void const* end, ZSTD_dictTableLoadMethod_e dtlm) +{ + const ZSTD_compressionParameters* const cParams = &ms->cParams; + U32* const hashLarge = ms->hashTable; + U32 const hBitsL = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + U32 const mls = cParams->minMatch; + U32* const hashSmall = ms->chainTable; + U32 const hBitsS = cParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* const base = ms->window.base; + const BYTE* ip = base + ms->nextToUpdate; + const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; + const U32 fastHashFillStep = 3; + + /* Always insert every fastHashFillStep position into the hash tables. + * Insert the other positions into the large hash table if their entry + * is empty. + */ + for (; ip + fastHashFillStep - 1 <= iend; ip += fastHashFillStep) { + U32 const curr = (U32)(ip - base); + U32 i; + for (i = 0; i < fastHashFillStep; ++i) { + size_t const smHashAndTag = ZSTD_hashPtr(ip + i, hBitsS, mls); + size_t const lgHashAndTag = ZSTD_hashPtr(ip + i, hBitsL, 8); + if (i == 0) { + ZSTD_writeTaggedIndex(hashSmall, smHashAndTag, curr + i); + } + if (i == 0 || hashLarge[lgHashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { + ZSTD_writeTaggedIndex(hashLarge, lgHashAndTag, curr + i); + } + /* Only load extra positions for ZSTD_dtlm_full */ + if (dtlm == ZSTD_dtlm_fast) + break; + } } +} + +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillDoubleHashTableForCCtx(ZSTD_MatchState_t* ms, void const* end, ZSTD_dictTableLoadMethod_e dtlm) { const ZSTD_compressionParameters* const cParams = &ms->cParams; @@ -43,13 +85,26 @@ void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, /* Only load extra positions for ZSTD_dtlm_full */ if (dtlm == ZSTD_dtlm_fast) break; - } } + } } +} + +void ZSTD_fillDoubleHashTable(ZSTD_MatchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) +{ + if (tfp == ZSTD_tfp_forCDict) { + ZSTD_fillDoubleHashTableForCDict(ms, end, dtlm); + } else { + ZSTD_fillDoubleHashTableForCCtx(ms, end, dtlm); + } } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_doubleFast_noDict_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls /* template */) { ZSTD_compressionParameters const* cParams = &ms->cParams; @@ -67,7 +122,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; + U32 offsetSaved1 = 0, offsetSaved2 = 0; size_t mLength; U32 offset; @@ -88,9 +143,14 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( const BYTE* matchl0; /* the long match for ip */ const BYTE* matchs0; /* the short match for ip */ const BYTE* matchl1; /* the long match for ip1 */ + const BYTE* matchs0_safe; /* matchs0 or safe address */ const BYTE* ip = istart; /* the current position */ const BYTE* ip1; /* the next position */ + /* Array of ~random data, should have low probability of matching data + * we load from here instead of from tables, if matchl0/matchl1 are + * invalid indices. Used to avoid unpredictable branches. */ + const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4}; DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_noDict_generic"); @@ -100,8 +160,8 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( U32 const current = (U32)(ip - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, current, cParams->windowLog); U32 const maxRep = current - windowLow; - if (offset_2 > maxRep) offsetSaved = offset_2, offset_2 = 0; - if (offset_1 > maxRep) offsetSaved = offset_1, offset_1 = 0; + if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 > maxRep) offsetSaved1 = offset_1, offset_1 = 0; } /* Outer Loop: one iteration per match found and stored */ @@ -131,30 +191,35 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( if ((offset_1 > 0) & (MEM_read32(ip+1-offset_1) == MEM_read32(ip+1))) { mLength = ZSTD_count(ip+1+4, ip+1+4-offset_1, iend) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); goto _match_stored; } hl1 = ZSTD_hashPtr(ip1, hBitsL, 8); - if (idxl0 > prefixLowestIndex) { + /* idxl0 > prefixLowestIndex is a (somewhat) unpredictable branch. + * However expression below complies into conditional move. Since + * match is unlikely and we only *branch* on idxl0 > prefixLowestIndex + * if there is a match, all branches become predictable. */ + { const BYTE* const matchl0_safe = ZSTD_selectAddr(idxl0, prefixLowestIndex, matchl0, &dummy[0]); + /* check prefix long match */ - if (MEM_read64(matchl0) == MEM_read64(ip)) { + if (MEM_read64(matchl0_safe) == MEM_read64(ip) && matchl0_safe == matchl0) { mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8; offset = (U32)(ip-matchl0); while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */ goto _match_found; - } - } + } } idxl1 = hashLong[hl1]; matchl1 = base + idxl1; - if (idxs0 > prefixLowestIndex) { - /* check prefix short match */ - if (MEM_read32(matchs0) == MEM_read32(ip)) { - goto _search_next_long; - } + /* Same optimization as matchl0 above */ + matchs0_safe = ZSTD_selectAddr(idxs0, prefixLowestIndex, matchs0, &dummy[0]); + + /* check prefix short match */ + if(MEM_read32(matchs0_safe) == MEM_read32(ip) && matchs0_safe == matchs0) { + goto _search_next_long; } if (ip1 >= nextStep) { @@ -175,30 +240,36 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( } while (ip1 <= ilimit); _cleanup: + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; + /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); _search_next_long: - /* check prefix long +1 match */ - if (idxl1 > prefixLowestIndex) { - if (MEM_read64(matchl1) == MEM_read64(ip1)) { + /* short match found: let's check for a longer one */ + mLength = ZSTD_count(ip+4, matchs0+4, iend) + 4; + offset = (U32)(ip - matchs0); + + /* check long match at +1 position */ + if ((idxl1 > prefixLowestIndex) && (MEM_read64(matchl1) == MEM_read64(ip1))) { + size_t const l1len = ZSTD_count(ip1+8, matchl1+8, iend) + 8; + if (l1len > mLength) { + /* use the long match instead */ ip = ip1; - mLength = ZSTD_count(ip+8, matchl1+8, iend) + 8; + mLength = l1len; offset = (U32)(ip-matchl1); - while (((ip>anchor) & (matchl1>prefixLowest)) && (ip[-1] == matchl1[-1])) { ip--; matchl1--; mLength++; } /* catch up */ - goto _match_found; + matchs0 = matchl1; } } - /* if no long +1 match, explore the short match we found */ - mLength = ZSTD_count(ip+4, matchs0+4, iend) + 4; - offset = (U32)(ip - matchs0); - while (((ip>anchor) & (matchs0>prefixLowest)) && (ip[-1] == matchs0[-1])) { ip--; matchs0--; mLength++; } /* catch up */ + while (((ip>anchor) & (matchs0>prefixLowest)) && (ip[-1] == matchs0[-1])) { ip--; matchs0--; mLength++; } /* complete backward */ /* fall-through */ @@ -217,7 +288,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( hashLong[hl1] = (U32)(ip1 - base); } - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); _match_stored: /* match found */ @@ -243,7 +314,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( U32 const tmpOff = offset_2; offset_2 = offset_1; offset_1 = tmpOff; /* swap offset_2 <=> offset_1 */ hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip-base); hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip-base); - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, rLength); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, rLength); ip += rLength; anchor = ip; continue; /* faster when present ... (?) */ @@ -254,8 +325,9 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls /* template */) { @@ -275,9 +347,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; - const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_MatchState_t* const dms = ms->dictMatchState; const ZSTD_compressionParameters* const dictCParams = &dms->cParams; const U32* const dictHashLong = dms->hashTable; const U32* const dictHashSmall = dms->chainTable; @@ -286,8 +357,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* const dictStart = dictBase + dictStartIndex; const BYTE* const dictEnd = dms->window.nextSrc; const U32 dictIndexDelta = prefixLowestIndex - (U32)(dictEnd - dictBase); - const U32 dictHBitsL = dictCParams->hashLog; - const U32 dictHBitsS = dictCParams->chainLog; + const U32 dictHBitsL = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + const U32 dictHBitsS = dictCParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; const U32 dictAndPrefixLength = (U32)((ip - prefixLowest) + (dictEnd - dictStart)); DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_dictMatchState_generic"); @@ -295,6 +366,13 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( /* if a dictionary is attached, it must be within window range */ assert(ms->window.dictLimit + (1U << cParams->windowLog) >= endIndex); + if (ms->prefetchCDictTables) { + size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); + size_t const chainTableBytes = (((size_t)1) << dictCParams->chainLog) * sizeof(U32); + PREFETCH_AREA(dictHashLong, hashTableBytes); + PREFETCH_AREA(dictHashSmall, chainTableBytes); + } + /* init */ ip += (dictAndPrefixLength == 0); @@ -309,8 +387,12 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( U32 offset; size_t const h2 = ZSTD_hashPtr(ip, hBitsL, 8); size_t const h = ZSTD_hashPtr(ip, hBitsS, mls); - size_t const dictHL = ZSTD_hashPtr(ip, dictHBitsL, 8); - size_t const dictHS = ZSTD_hashPtr(ip, dictHBitsS, mls); + size_t const dictHashAndTagL = ZSTD_hashPtr(ip, dictHBitsL, 8); + size_t const dictHashAndTagS = ZSTD_hashPtr(ip, dictHBitsS, mls); + U32 const dictMatchIndexAndTagL = dictHashLong[dictHashAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS]; + U32 const dictMatchIndexAndTagS = dictHashSmall[dictHashAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS]; + int const dictTagsMatchL = ZSTD_comparePackedTags(dictMatchIndexAndTagL, dictHashAndTagL); + int const dictTagsMatchS = ZSTD_comparePackedTags(dictMatchIndexAndTagS, dictHashAndTagS); U32 const curr = (U32)(ip-base); U32 const matchIndexL = hashLong[h2]; U32 matchIndexS = hashSmall[h]; @@ -323,26 +405,24 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( hashLong[h2] = hashSmall[h] = curr; /* update hash tables */ /* check repcode */ - if (((U32)((prefixLowestIndex-1) - repIndex) >= 3 /* intentional underflow */) + if ((ZSTD_index_overlap_check(prefixLowestIndex, repIndex)) && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); goto _match_stored; } - if (matchIndexL > prefixLowestIndex) { + if ((matchIndexL >= prefixLowestIndex) && (MEM_read64(matchLong) == MEM_read64(ip))) { /* check prefix long match */ - if (MEM_read64(matchLong) == MEM_read64(ip)) { - mLength = ZSTD_count(ip+8, matchLong+8, iend) + 8; - offset = (U32)(ip-matchLong); - while (((ip>anchor) & (matchLong>prefixLowest)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ - goto _match_found; - } - } else { + mLength = ZSTD_count(ip+8, matchLong+8, iend) + 8; + offset = (U32)(ip-matchLong); + while (((ip>anchor) & (matchLong>prefixLowest)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ + goto _match_found; + } else if (dictTagsMatchL) { /* check dictMatchState long match */ - U32 const dictMatchIndexL = dictHashLong[dictHL]; + U32 const dictMatchIndexL = dictMatchIndexAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS; const BYTE* dictMatchL = dictBase + dictMatchIndexL; assert(dictMatchL < dictEnd); @@ -354,13 +434,13 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( } } if (matchIndexS > prefixLowestIndex) { - /* check prefix short match */ + /* short match candidate */ if (MEM_read32(match) == MEM_read32(ip)) { goto _search_next_long; } - } else { + } else if (dictTagsMatchS) { /* check dictMatchState short match */ - U32 const dictMatchIndexS = dictHashSmall[dictHS]; + U32 const dictMatchIndexS = dictMatchIndexAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS; match = dictBase + dictMatchIndexS; matchIndexS = dictMatchIndexS + dictIndexDelta; @@ -375,25 +455,24 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( continue; _search_next_long: - { size_t const hl3 = ZSTD_hashPtr(ip+1, hBitsL, 8); - size_t const dictHLNext = ZSTD_hashPtr(ip+1, dictHBitsL, 8); + size_t const dictHashAndTagL3 = ZSTD_hashPtr(ip+1, dictHBitsL, 8); U32 const matchIndexL3 = hashLong[hl3]; + U32 const dictMatchIndexAndTagL3 = dictHashLong[dictHashAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS]; + int const dictTagsMatchL3 = ZSTD_comparePackedTags(dictMatchIndexAndTagL3, dictHashAndTagL3); const BYTE* matchL3 = base + matchIndexL3; hashLong[hl3] = curr + 1; /* check prefix long +1 match */ - if (matchIndexL3 > prefixLowestIndex) { - if (MEM_read64(matchL3) == MEM_read64(ip+1)) { - mLength = ZSTD_count(ip+9, matchL3+8, iend) + 8; - ip++; - offset = (U32)(ip-matchL3); - while (((ip>anchor) & (matchL3>prefixLowest)) && (ip[-1] == matchL3[-1])) { ip--; matchL3--; mLength++; } /* catch up */ - goto _match_found; - } - } else { + if ((matchIndexL3 >= prefixLowestIndex) && (MEM_read64(matchL3) == MEM_read64(ip+1))) { + mLength = ZSTD_count(ip+9, matchL3+8, iend) + 8; + ip++; + offset = (U32)(ip-matchL3); + while (((ip>anchor) & (matchL3>prefixLowest)) && (ip[-1] == matchL3[-1])) { ip--; matchL3--; mLength++; } /* catch up */ + goto _match_found; + } else if (dictTagsMatchL3) { /* check dict long +1 match */ - U32 const dictMatchIndexL3 = dictHashLong[dictHLNext]; + U32 const dictMatchIndexL3 = dictMatchIndexAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS; const BYTE* dictMatchL3 = dictBase + dictMatchIndexL3; assert(dictMatchL3 < dictEnd); if (dictMatchL3 > dictStart && MEM_read64(dictMatchL3) == MEM_read64(ip+1)) { @@ -419,7 +498,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); _match_stored: /* match found */ @@ -443,12 +522,12 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* repMatch2 = repIndex2 < prefixLowestIndex ? dictBase + repIndex2 - dictIndexDelta : base + repIndex2; - if ( ((U32)((prefixLowestIndex-1) - (U32)repIndex2) >= 3 /* intentional overflow */) + if ( (ZSTD_index_overlap_check(prefixLowestIndex, repIndex2)) && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { const BYTE* const repEnd2 = repIndex2 < prefixLowestIndex ? dictEnd : iend; size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixLowest) + 4; U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; ip += repLength2; @@ -461,8 +540,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( } /* while (ip < ilimit) */ /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1; + rep[1] = offset_2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -470,7 +549,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( #define ZSTD_GEN_DFAST_FN(dictMode, mls) \ static size_t ZSTD_compressBlock_doubleFast_##dictMode##_##mls( \ - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \ + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \ void const* src, size_t srcSize) \ { \ return ZSTD_compressBlock_doubleFast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mls); \ @@ -488,7 +567,7 @@ ZSTD_GEN_DFAST_FN(dictMatchState, 7) size_t ZSTD_compressBlock_doubleFast( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { const U32 mls = ms->cParams.minMatch; @@ -508,7 +587,7 @@ size_t ZSTD_compressBlock_doubleFast( size_t ZSTD_compressBlock_doubleFast_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { const U32 mls = ms->cParams.minMatch; @@ -527,8 +606,10 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState( } -static size_t ZSTD_compressBlock_doubleFast_extDict_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_doubleFast_extDict_generic( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls /* template */) { @@ -579,13 +660,13 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( size_t mLength; hashSmall[hSmall] = hashLong[hLong] = curr; /* update hash table */ - if ((((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow : ensure repIndex doesn't overlap dict + prefix */ + if (((ZSTD_index_overlap_check(prefixStartIndex, repIndex)) & (offset_1 <= curr+1 - dictStartIndex)) /* note: we are searching at curr+1 */ && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { const BYTE* repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); } else { if ((matchLongIndex > dictStartIndex) && (MEM_read64(matchLong) == MEM_read64(ip))) { const BYTE* const matchEnd = matchLongIndex < prefixStartIndex ? dictEnd : iend; @@ -596,7 +677,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( while (((ip>anchor) & (matchLong>lowMatchPtr)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); } else if ((matchIndex > dictStartIndex) && (MEM_read32(match) == MEM_read32(ip))) { size_t const h3 = ZSTD_hashPtr(ip+1, hBitsL, 8); @@ -621,7 +702,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( } offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); } else { ip += ((ip-anchor) >> kSearchStrength) + 1; @@ -647,13 +728,13 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( U32 const current2 = (U32)(ip-base); U32 const repIndex2 = current2 - offset_2; const BYTE* repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; - if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) /* intentional overflow : ensure repIndex2 doesn't overlap dict + prefix */ + if ( ((ZSTD_index_overlap_check(prefixStartIndex, repIndex2)) & (offset_2 <= current2 - dictStartIndex)) && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; ip += repLength2; @@ -677,7 +758,7 @@ ZSTD_GEN_DFAST_FN(extDict, 6) ZSTD_GEN_DFAST_FN(extDict, 7) size_t ZSTD_compressBlock_doubleFast_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { U32 const mls = ms->cParams.minMatch; @@ -694,3 +775,5 @@ size_t ZSTD_compressBlock_doubleFast_extDict( return ZSTD_compressBlock_doubleFast_extDict_7(ms, seqStore, rep, src, srcSize); } } + +#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ diff --git a/lib/zstd/compress/zstd_double_fast.h b/lib/zstd/compress/zstd_double_fast.h index 6822bde65a1d..011556ce56f7 100644 --- a/lib/zstd/compress/zstd_double_fast.h +++ b/lib/zstd/compress/zstd_double_fast.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,22 +12,32 @@ #ifndef ZSTD_DOUBLE_FAST_H #define ZSTD_DOUBLE_FAST_H - #include "../common/mem.h" /* U32 */ #include "zstd_compress_internal.h" /* ZSTD_CCtx, size_t */ -void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, - void const* end, ZSTD_dictTableLoadMethod_e dtlm); +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + +void ZSTD_fillDoubleHashTable(ZSTD_MatchState_t* ms, + void const* end, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp); + size_t ZSTD_compressBlock_doubleFast( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_doubleFast_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_doubleFast_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST ZSTD_compressBlock_doubleFast +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE ZSTD_compressBlock_doubleFast_dictMatchState +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT ZSTD_compressBlock_doubleFast_extDict +#else +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST NULL +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT NULL +#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ #endif /* ZSTD_DOUBLE_FAST_H */ diff --git a/lib/zstd/compress/zstd_fast.c b/lib/zstd/compress/zstd_fast.c index a752e6beab52..60e07e839e5f 100644 --- a/lib/zstd/compress/zstd_fast.c +++ b/lib/zstd/compress/zstd_fast.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,8 +12,46 @@ #include "zstd_compress_internal.h" /* ZSTD_hashPtr, ZSTD_count, ZSTD_storeSeq */ #include "zstd_fast.h" +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillHashTableForCDict(ZSTD_MatchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm) +{ + const ZSTD_compressionParameters* const cParams = &ms->cParams; + U32* const hashTable = ms->hashTable; + U32 const hBits = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + U32 const mls = cParams->minMatch; + const BYTE* const base = ms->window.base; + const BYTE* ip = base + ms->nextToUpdate; + const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; + const U32 fastHashFillStep = 3; + + /* Currently, we always use ZSTD_dtlm_full for filling CDict tables. + * Feel free to remove this assert if there's a good reason! */ + assert(dtlm == ZSTD_dtlm_full); + + /* Always insert every fastHashFillStep position into the hash table. + * Insert the other positions if their hash entry is empty. + */ + for ( ; ip + fastHashFillStep < iend + 2; ip += fastHashFillStep) { + U32 const curr = (U32)(ip - base); + { size_t const hashAndTag = ZSTD_hashPtr(ip, hBits, mls); + ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr); } + + if (dtlm == ZSTD_dtlm_fast) continue; + /* Only load extra positions for ZSTD_dtlm_full */ + { U32 p; + for (p = 1; p < fastHashFillStep; ++p) { + size_t const hashAndTag = ZSTD_hashPtr(ip + p, hBits, mls); + if (hashTable[hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { /* not yet filled */ + ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr + p); + } } } } +} -void ZSTD_fillHashTable(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillHashTableForCCtx(ZSTD_MatchState_t* ms, const void* const end, ZSTD_dictTableLoadMethod_e dtlm) { @@ -25,6 +64,10 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; const U32 fastHashFillStep = 3; + /* Currently, we always use ZSTD_dtlm_fast for filling CCtx tables. + * Feel free to remove this assert if there's a good reason! */ + assert(dtlm == ZSTD_dtlm_fast); + /* Always insert every fastHashFillStep position into the hash table. * Insert the other positions if their hash entry is empty. */ @@ -42,6 +85,60 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, } } } } } +void ZSTD_fillHashTable(ZSTD_MatchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) +{ + if (tfp == ZSTD_tfp_forCDict) { + ZSTD_fillHashTableForCDict(ms, end, dtlm); + } else { + ZSTD_fillHashTableForCCtx(ms, end, dtlm); + } +} + + +typedef int (*ZSTD_match4Found) (const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit); + +static int +ZSTD_match4Found_cmov(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit) +{ + /* Array of ~random data, should have low probability of matching data. + * Load from here if the index is invalid. + * Used to avoid unpredictable branches. */ + static const BYTE dummy[] = {0x12,0x34,0x56,0x78}; + + /* currentIdx >= lowLimit is a (somewhat) unpredictable branch. + * However expression below compiles into conditional move. + */ + const BYTE* mvalAddr = ZSTD_selectAddr(matchIdx, idxLowLimit, matchAddress, dummy); + /* Note: this used to be written as : return test1 && test2; + * Unfortunately, once inlined, these tests become branches, + * in which case it becomes critical that they are executed in the right order (test1 then test2). + * So we have to write these tests in a specific manner to ensure their ordering. + */ + if (MEM_read32(currentPtr) != MEM_read32(mvalAddr)) return 0; + /* force ordering of these tests, which matters once the function is inlined, as they become branches */ + __asm__(""); + return matchIdx >= idxLowLimit; +} + +static int +ZSTD_match4Found_branch(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit) +{ + /* using a branch instead of a cmov, + * because it's faster in scenarios where matchIdx >= idxLowLimit is generally true, + * aka almost all candidates are within range */ + U32 mval; + if (matchIdx >= idxLowLimit) { + mval = MEM_read32(matchAddress); + } else { + mval = MEM_read32(currentPtr) ^ 1; /* guaranteed to not match. */ + } + + return (MEM_read32(currentPtr) == mval); +} + /* * If you squint hard enough (and ignore repcodes), the search operation at any @@ -89,17 +186,17 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, * * This is also the work we do at the beginning to enter the loop initially. */ -FORCE_INLINE_TEMPLATE size_t -ZSTD_compressBlock_fast_noDict_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_fast_noDict_generic( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, - U32 const mls, U32 const hasStep) + U32 const mls, int useCmov) { const ZSTD_compressionParameters* const cParams = &ms->cParams; U32* const hashTable = ms->hashTable; U32 const hlog = cParams->hashLog; - /* support stepSize of 0 */ - size_t const stepSize = hasStep ? (cParams->targetLength + !(cParams->targetLength) + 1) : 2; + size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; /* min 2 */ const BYTE* const base = ms->window.base; const BYTE* const istart = (const BYTE*)src; const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); @@ -117,12 +214,11 @@ ZSTD_compressBlock_fast_noDict_generic( U32 rep_offset1 = rep[0]; U32 rep_offset2 = rep[1]; - U32 offsetSaved = 0; + U32 offsetSaved1 = 0, offsetSaved2 = 0; size_t hash0; /* hash for ip0 */ size_t hash1; /* hash for ip1 */ - U32 idx; /* match idx for ip0 */ - U32 mval; /* src value at match idx */ + U32 matchIdx; /* match idx for ip0 */ U32 offcode; const BYTE* match0; @@ -135,14 +231,15 @@ ZSTD_compressBlock_fast_noDict_generic( size_t step; const BYTE* nextStep; const size_t kStepIncr = (1 << (kSearchStrength - 1)); + const ZSTD_match4Found matchFound = useCmov ? ZSTD_match4Found_cmov : ZSTD_match4Found_branch; DEBUGLOG(5, "ZSTD_compressBlock_fast_generic"); ip0 += (ip0 == prefixStart); { U32 const curr = (U32)(ip0 - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, cParams->windowLog); U32 const maxRep = curr - windowLow; - if (rep_offset2 > maxRep) offsetSaved = rep_offset2, rep_offset2 = 0; - if (rep_offset1 > maxRep) offsetSaved = rep_offset1, rep_offset1 = 0; + if (rep_offset2 > maxRep) offsetSaved2 = rep_offset2, rep_offset2 = 0; + if (rep_offset1 > maxRep) offsetSaved1 = rep_offset1, rep_offset1 = 0; } /* start each op */ @@ -163,7 +260,7 @@ ZSTD_compressBlock_fast_noDict_generic( hash0 = ZSTD_hashPtr(ip0, hlog, mls); hash1 = ZSTD_hashPtr(ip1, hlog, mls); - idx = hashTable[hash0]; + matchIdx = hashTable[hash0]; do { /* load repcode match for ip[2]*/ @@ -180,26 +277,28 @@ ZSTD_compressBlock_fast_noDict_generic( mLength = ip0[-1] == match0[-1]; ip0 -= mLength; match0 -= mLength; - offcode = STORE_REPCODE_1; + offcode = REPCODE1_TO_OFFBASE; mLength += 4; + + /* Write next hash table entry: it's already calculated. + * This write is known to be safe because ip1 is before the + * repcode (ip2). */ + hashTable[hash1] = (U32)(ip1 - base); + goto _match; } - /* load match for ip[0] */ - if (idx >= prefixStartIndex) { - mval = MEM_read32(base + idx); - } else { - mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */ - } + if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) { + /* Write next hash table entry (it's already calculated). + * This write is known to be safe because the ip1 == ip0 + 1, + * so searching will resume after ip1 */ + hashTable[hash1] = (U32)(ip1 - base); - /* check match at ip[0] */ - if (MEM_read32(ip0) == mval) { - /* found a match! */ goto _offset; } /* lookup ip[1] */ - idx = hashTable[hash1]; + matchIdx = hashTable[hash1]; /* hash ip[2] */ hash0 = hash1; @@ -214,21 +313,19 @@ ZSTD_compressBlock_fast_noDict_generic( current0 = (U32)(ip0 - base); hashTable[hash0] = current0; - /* load match for ip[0] */ - if (idx >= prefixStartIndex) { - mval = MEM_read32(base + idx); - } else { - mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */ - } - - /* check match at ip[0] */ - if (MEM_read32(ip0) == mval) { - /* found a match! */ + if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) { + /* Write next hash table entry, since it's already calculated */ + if (step <= 4) { + /* Avoid writing an index if it's >= position where search will resume. + * The minimum possible match has length 4, so search can resume at ip0 + 4. + */ + hashTable[hash1] = (U32)(ip1 - base); + } goto _offset; } /* lookup ip[1] */ - idx = hashTable[hash1]; + matchIdx = hashTable[hash1]; /* hash ip[2] */ hash0 = hash1; @@ -250,13 +347,28 @@ ZSTD_compressBlock_fast_noDict_generic( } while (ip3 < ilimit); _cleanup: - /* Note that there are probably still a couple positions we could search. + /* Note that there are probably still a couple positions one could search. * However, it seems to be a meaningful performance hit to try to search * them. So let's not. */ + /* When the repcodes are outside of the prefix, we set them to zero before the loop. + * When the offsets are still zero, we need to restore them after the block to have a correct + * repcode history. If only one offset was invalid, it is easy. The tricky case is when both + * offsets were invalid. We need to figure out which offset to refill with. + * - If both offsets are zero they are in the same order. + * - If both offsets are non-zero, we won't restore the offsets from `offsetSaved[12]`. + * - If only one is zero, we need to decide which offset to restore. + * - If rep_offset1 is non-zero, then rep_offset2 must be offsetSaved1. + * - It is impossible for rep_offset2 to be non-zero. + * + * So if rep_offset1 started invalid (offsetSaved1 != 0) and became valid (rep_offset1 != 0), then + * set rep[0] = rep_offset1 and rep[1] = offsetSaved1. + */ + offsetSaved2 = ((offsetSaved1 != 0) && (rep_offset1 != 0)) ? offsetSaved1 : offsetSaved2; + /* save reps for next block */ - rep[0] = rep_offset1 ? rep_offset1 : offsetSaved; - rep[1] = rep_offset2 ? rep_offset2 : offsetSaved; + rep[0] = rep_offset1 ? rep_offset1 : offsetSaved1; + rep[1] = rep_offset2 ? rep_offset2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -264,10 +376,10 @@ ZSTD_compressBlock_fast_noDict_generic( _offset: /* Requires: ip0, idx */ /* Compute the offset code. */ - match0 = base + idx; + match0 = base + matchIdx; rep_offset2 = rep_offset1; rep_offset1 = (U32)(ip0-match0); - offcode = STORE_OFFSET(rep_offset1); + offcode = OFFSET_TO_OFFBASE(rep_offset1); mLength = 4; /* Count the backwards match length. */ @@ -287,11 +399,6 @@ ZSTD_compressBlock_fast_noDict_generic( ip0 += mLength; anchor = ip0; - /* write next hash table entry */ - if (ip1 < ip0) { - hashTable[hash1] = (U32)(ip1 - base); - } - /* Fill table and check for immediate repcode. */ if (ip0 <= ilimit) { /* Fill Table */ @@ -306,7 +413,7 @@ ZSTD_compressBlock_fast_noDict_generic( { U32 const tmpOff = rep_offset2; rep_offset2 = rep_offset1; rep_offset1 = tmpOff; } /* swap rep_offset2 <=> rep_offset1 */ hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); ip0 += rLength; - ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, STORE_REPCODE_1, rLength); + ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, REPCODE1_TO_OFFBASE, rLength); anchor = ip0; continue; /* faster when present (confirmed on gcc-8) ... (?) */ } } } @@ -314,12 +421,12 @@ ZSTD_compressBlock_fast_noDict_generic( goto _start; } -#define ZSTD_GEN_FAST_FN(dictMode, mls, step) \ - static size_t ZSTD_compressBlock_fast_##dictMode##_##mls##_##step( \ - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \ +#define ZSTD_GEN_FAST_FN(dictMode, mml, cmov) \ + static size_t ZSTD_compressBlock_fast_##dictMode##_##mml##_##cmov( \ + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \ void const* src, size_t srcSize) \ { \ - return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mls, step); \ + return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mml, cmov); \ } ZSTD_GEN_FAST_FN(noDict, 4, 1) @@ -333,13 +440,15 @@ ZSTD_GEN_FAST_FN(noDict, 6, 0) ZSTD_GEN_FAST_FN(noDict, 7, 0) size_t ZSTD_compressBlock_fast( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - U32 const mls = ms->cParams.minMatch; + U32 const mml = ms->cParams.minMatch; + /* use cmov when "candidate in range" branch is likely unpredictable */ + int const useCmov = ms->cParams.windowLog < 19; assert(ms->dictMatchState == NULL); - if (ms->cParams.targetLength > 1) { - switch(mls) + if (useCmov) { + switch(mml) { default: /* includes case 3 */ case 4 : @@ -352,7 +461,8 @@ size_t ZSTD_compressBlock_fast( return ZSTD_compressBlock_fast_noDict_7_1(ms, seqStore, rep, src, srcSize); } } else { - switch(mls) + /* use a branch instead */ + switch(mml) { default: /* includes case 3 */ case 4 : @@ -364,13 +474,13 @@ size_t ZSTD_compressBlock_fast( case 7 : return ZSTD_compressBlock_fast_noDict_7_0(ms, seqStore, rep, src, srcSize); } - } } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_fast_dictMatchState_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls, U32 const hasStep) { const ZSTD_compressionParameters* const cParams = &ms->cParams; @@ -380,16 +490,16 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( U32 const stepSize = cParams->targetLength + !(cParams->targetLength); const BYTE* const base = ms->window.base; const BYTE* const istart = (const BYTE*)src; - const BYTE* ip = istart; + const BYTE* ip0 = istart; + const BYTE* ip1 = ip0 + stepSize; /* we assert below that stepSize >= 1 */ const BYTE* anchor = istart; const U32 prefixStartIndex = ms->window.dictLimit; const BYTE* const prefixStart = base + prefixStartIndex; const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; - const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_MatchState_t* const dms = ms->dictMatchState; const ZSTD_compressionParameters* const dictCParams = &dms->cParams ; const U32* const dictHashTable = dms->hashTable; const U32 dictStartIndex = dms->window.dictLimit; @@ -397,13 +507,13 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( const BYTE* const dictStart = dictBase + dictStartIndex; const BYTE* const dictEnd = dms->window.nextSrc; const U32 dictIndexDelta = prefixStartIndex - (U32)(dictEnd - dictBase); - const U32 dictAndPrefixLength = (U32)(ip - prefixStart + dictEnd - dictStart); - const U32 dictHLog = dictCParams->hashLog; + const U32 dictAndPrefixLength = (U32)(istart - prefixStart + dictEnd - dictStart); + const U32 dictHBits = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; /* if a dictionary is still attached, it necessarily means that * it is within window size. So we just check it. */ const U32 maxDistance = 1U << cParams->windowLog; - const U32 endIndex = (U32)((size_t)(ip - base) + srcSize); + const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); assert(endIndex - prefixStartIndex <= maxDistance); (void)maxDistance; (void)endIndex; /* these variables are not used when assert() is disabled */ @@ -413,106 +523,154 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( * when translating a dict index into a local index */ assert(prefixStartIndex >= (U32)(dictEnd - dictBase)); + if (ms->prefetchCDictTables) { + size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); + PREFETCH_AREA(dictHashTable, hashTableBytes); + } + /* init */ DEBUGLOG(5, "ZSTD_compressBlock_fast_dictMatchState_generic"); - ip += (dictAndPrefixLength == 0); + ip0 += (dictAndPrefixLength == 0); /* dictMatchState repCode checks don't currently handle repCode == 0 * disabling. */ assert(offset_1 <= dictAndPrefixLength); assert(offset_2 <= dictAndPrefixLength); - /* Main Search Loop */ - while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */ + /* Outer search loop */ + assert(stepSize >= 1); + while (ip1 <= ilimit) { /* repcode check at (ip0 + 1) is safe because ip0 < ip1 */ size_t mLength; - size_t const h = ZSTD_hashPtr(ip, hlog, mls); - U32 const curr = (U32)(ip-base); - U32 const matchIndex = hashTable[h]; - const BYTE* match = base + matchIndex; - const U32 repIndex = curr + 1 - offset_1; - const BYTE* repMatch = (repIndex < prefixStartIndex) ? - dictBase + (repIndex - dictIndexDelta) : - base + repIndex; - hashTable[h] = curr; /* update hash table */ - - if ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow : ensure repIndex isn't overlapping dict + prefix */ - && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { - const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; - mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; - ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); - } else if ( (matchIndex <= prefixStartIndex) ) { - size_t const dictHash = ZSTD_hashPtr(ip, dictHLog, mls); - U32 const dictMatchIndex = dictHashTable[dictHash]; - const BYTE* dictMatch = dictBase + dictMatchIndex; - if (dictMatchIndex <= dictStartIndex || - MEM_read32(dictMatch) != MEM_read32(ip)) { - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; - } else { - /* found a dict match */ - U32 const offset = (U32)(curr-dictMatchIndex-dictIndexDelta); - mLength = ZSTD_count_2segments(ip+4, dictMatch+4, iend, dictEnd, prefixStart) + 4; - while (((ip>anchor) & (dictMatch>dictStart)) - && (ip[-1] == dictMatch[-1])) { - ip--; dictMatch--; mLength++; + size_t hash0 = ZSTD_hashPtr(ip0, hlog, mls); + + size_t const dictHashAndTag0 = ZSTD_hashPtr(ip0, dictHBits, mls); + U32 dictMatchIndexAndTag = dictHashTable[dictHashAndTag0 >> ZSTD_SHORT_CACHE_TAG_BITS]; + int dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag0); + + U32 matchIndex = hashTable[hash0]; + U32 curr = (U32)(ip0 - base); + size_t step = stepSize; + const size_t kStepIncr = 1 << kSearchStrength; + const BYTE* nextStep = ip0 + kStepIncr; + + /* Inner search loop */ + while (1) { + const BYTE* match = base + matchIndex; + const U32 repIndex = curr + 1 - offset_1; + const BYTE* repMatch = (repIndex < prefixStartIndex) ? + dictBase + (repIndex - dictIndexDelta) : + base + repIndex; + const size_t hash1 = ZSTD_hashPtr(ip1, hlog, mls); + size_t const dictHashAndTag1 = ZSTD_hashPtr(ip1, dictHBits, mls); + hashTable[hash0] = curr; /* update hash table */ + + if ((ZSTD_index_overlap_check(prefixStartIndex, repIndex)) + && (MEM_read32(repMatch) == MEM_read32(ip0 + 1))) { + const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip0 + 1 + 4, repMatch + 4, iend, repMatchEnd, prefixStart) + 4; + ip0++; + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); + break; + } + + if (dictTagsMatch) { + /* Found a possible dict match */ + const U32 dictMatchIndex = dictMatchIndexAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* dictMatch = dictBase + dictMatchIndex; + if (dictMatchIndex > dictStartIndex && + MEM_read32(dictMatch) == MEM_read32(ip0)) { + /* To replicate extDict parse behavior, we only use dict matches when the normal matchIndex is invalid */ + if (matchIndex <= prefixStartIndex) { + U32 const offset = (U32) (curr - dictMatchIndex - dictIndexDelta); + mLength = ZSTD_count_2segments(ip0 + 4, dictMatch + 4, iend, dictEnd, prefixStart) + 4; + while (((ip0 > anchor) & (dictMatch > dictStart)) + && (ip0[-1] == dictMatch[-1])) { + ip0--; + dictMatch--; + mLength++; + } /* catch up */ + offset_2 = offset_1; + offset_1 = offset; + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + break; + } + } + } + + if (ZSTD_match4Found_cmov(ip0, match, matchIndex, prefixStartIndex)) { + /* found a regular match of size >= 4 */ + U32 const offset = (U32) (ip0 - match); + mLength = ZSTD_count(ip0 + 4, match + 4, iend) + 4; + while (((ip0 > anchor) & (match > prefixStart)) + && (ip0[-1] == match[-1])) { + ip0--; + match--; + mLength++; } /* catch up */ offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + break; } - } else if (MEM_read32(match) != MEM_read32(ip)) { - /* it's not a match, and we're not going to check the dictionary */ - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; - } else { - /* found a regular match */ - U32 const offset = (U32)(ip-match); - mLength = ZSTD_count(ip+4, match+4, iend) + 4; - while (((ip>anchor) & (match>prefixStart)) - && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ - offset_2 = offset_1; - offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); - } + + /* Prepare for next iteration */ + dictMatchIndexAndTag = dictHashTable[dictHashAndTag1 >> ZSTD_SHORT_CACHE_TAG_BITS]; + dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag1); + matchIndex = hashTable[hash1]; + + if (ip1 >= nextStep) { + step++; + nextStep += kStepIncr; + } + ip0 = ip1; + ip1 = ip1 + step; + if (ip1 > ilimit) goto _cleanup; + + curr = (U32)(ip0 - base); + hash0 = hash1; + } /* end inner search loop */ /* match found */ - ip += mLength; - anchor = ip; + assert(mLength); + ip0 += mLength; + anchor = ip0; - if (ip <= ilimit) { + if (ip0 <= ilimit) { /* Fill Table */ assert(base+curr+2 > istart); /* check base overflow */ hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; /* here because curr+2 could be > iend-8 */ - hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); + hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); /* check immediate repcode */ - while (ip <= ilimit) { - U32 const current2 = (U32)(ip-base); + while (ip0 <= ilimit) { + U32 const current2 = (U32)(ip0-base); U32 const repIndex2 = current2 - offset_2; const BYTE* repMatch2 = repIndex2 < prefixStartIndex ? dictBase - dictIndexDelta + repIndex2 : base + repIndex2; - if ( ((U32)((prefixStartIndex-1) - (U32)repIndex2) >= 3 /* intentional overflow */) - && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { + if ( (ZSTD_index_overlap_check(prefixStartIndex, repIndex2)) + && (MEM_read32(repMatch2) == MEM_read32(ip0))) { const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; - size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); - hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; - ip += repLength2; - anchor = ip; + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = current2; + ip0 += repLength2; + anchor = ip0; continue; } break; } } + + /* Prepare for next iteration */ + assert(ip0 == anchor); + ip1 = ip0 + stepSize; } +_cleanup: /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1; + rep[1] = offset_2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -525,7 +683,7 @@ ZSTD_GEN_FAST_FN(dictMatchState, 6, 0) ZSTD_GEN_FAST_FN(dictMatchState, 7, 0) size_t ZSTD_compressBlock_fast_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { U32 const mls = ms->cParams.minMatch; @@ -545,19 +703,20 @@ size_t ZSTD_compressBlock_fast_dictMatchState( } -static size_t ZSTD_compressBlock_fast_extDict_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_fast_extDict_generic( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls, U32 const hasStep) { const ZSTD_compressionParameters* const cParams = &ms->cParams; U32* const hashTable = ms->hashTable; U32 const hlog = cParams->hashLog; /* support stepSize of 0 */ - U32 const stepSize = cParams->targetLength + !(cParams->targetLength); + size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; const BYTE* const base = ms->window.base; const BYTE* const dictBase = ms->window.dictBase; const BYTE* const istart = (const BYTE*)src; - const BYTE* ip = istart; const BYTE* anchor = istart; const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); const U32 lowLimit = ZSTD_getLowestMatchIndex(ms, endIndex, cParams->windowLog); @@ -570,6 +729,28 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - 8; U32 offset_1=rep[0], offset_2=rep[1]; + U32 offsetSaved1 = 0, offsetSaved2 = 0; + + const BYTE* ip0 = istart; + const BYTE* ip1; + const BYTE* ip2; + const BYTE* ip3; + U32 current0; + + + size_t hash0; /* hash for ip0 */ + size_t hash1; /* hash for ip1 */ + U32 idx; /* match idx for ip0 */ + const BYTE* idxBase; /* base pointer for idx */ + + U32 offcode; + const BYTE* match0; + size_t mLength; + const BYTE* matchEnd = 0; /* initialize to avoid warning, assert != 0 later */ + + size_t step; + const BYTE* nextStep; + const size_t kStepIncr = (1 << (kSearchStrength - 1)); (void)hasStep; /* not currently specialized on whether it's accelerated */ @@ -579,75 +760,202 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( if (prefixStartIndex == dictStartIndex) return ZSTD_compressBlock_fast(ms, seqStore, rep, src, srcSize); - /* Search Loop */ - while (ip < ilimit) { /* < instead of <=, because (ip+1) */ - const size_t h = ZSTD_hashPtr(ip, hlog, mls); - const U32 matchIndex = hashTable[h]; - const BYTE* const matchBase = matchIndex < prefixStartIndex ? dictBase : base; - const BYTE* match = matchBase + matchIndex; - const U32 curr = (U32)(ip-base); - const U32 repIndex = curr + 1 - offset_1; - const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; - const BYTE* const repMatch = repBase + repIndex; - hashTable[h] = curr; /* update hash table */ - DEBUGLOG(7, "offset_1 = %u , curr = %u", offset_1, curr); - - if ( ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow */ - & (offset_1 <= curr+1 - dictStartIndex) ) /* note: we are searching at curr+1 */ - && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { - const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; - size_t const rLength = ZSTD_count_2segments(ip+1 +4, repMatch +4, iend, repMatchEnd, prefixStart) + 4; - ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, rLength); - ip += rLength; - anchor = ip; - } else { - if ( (matchIndex < dictStartIndex) || - (MEM_read32(match) != MEM_read32(ip)) ) { - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; + { U32 const curr = (U32)(ip0 - base); + U32 const maxRep = curr - dictStartIndex; + if (offset_2 >= maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 >= maxRep) offsetSaved1 = offset_1, offset_1 = 0; + } + + /* start each op */ +_start: /* Requires: ip0 */ + + step = stepSize; + nextStep = ip0 + kStepIncr; + + /* calculate positions, ip0 - anchor == 0, so we skip step calc */ + ip1 = ip0 + 1; + ip2 = ip0 + step; + ip3 = ip2 + 1; + + if (ip3 >= ilimit) { + goto _cleanup; + } + + hash0 = ZSTD_hashPtr(ip0, hlog, mls); + hash1 = ZSTD_hashPtr(ip1, hlog, mls); + + idx = hashTable[hash0]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + do { + { /* load repcode match for ip[2] */ + U32 const current2 = (U32)(ip2 - base); + U32 const repIndex = current2 - offset_1; + const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; + U32 rval; + if ( ((U32)(prefixStartIndex - repIndex) >= 4) /* intentional underflow */ + & (offset_1 > 0) ) { + rval = MEM_read32(repBase + repIndex); + } else { + rval = MEM_read32(ip2) ^ 1; /* guaranteed to not match. */ } - { const BYTE* const matchEnd = matchIndex < prefixStartIndex ? dictEnd : iend; - const BYTE* const lowMatchPtr = matchIndex < prefixStartIndex ? dictStart : prefixStart; - U32 const offset = curr - matchIndex; - size_t mLength = ZSTD_count_2segments(ip+4, match+4, iend, matchEnd, prefixStart) + 4; - while (((ip>anchor) & (match>lowMatchPtr)) && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ - offset_2 = offset_1; offset_1 = offset; /* update offset history */ - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); - ip += mLength; - anchor = ip; + + /* write back hash table entry */ + current0 = (U32)(ip0 - base); + hashTable[hash0] = current0; + + /* check repcode at ip[2] */ + if (MEM_read32(ip2) == rval) { + ip0 = ip2; + match0 = repBase + repIndex; + matchEnd = repIndex < prefixStartIndex ? dictEnd : iend; + assert((match0 != prefixStart) & (match0 != dictStart)); + mLength = ip0[-1] == match0[-1]; + ip0 -= mLength; + match0 -= mLength; + offcode = REPCODE1_TO_OFFBASE; + mLength += 4; + goto _match; } } - if (ip <= ilimit) { - /* Fill Table */ - hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; - hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); - /* check immediate repcode */ - while (ip <= ilimit) { - U32 const current2 = (U32)(ip-base); - U32 const repIndex2 = current2 - offset_2; - const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; - if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) & (offset_2 <= curr - dictStartIndex)) /* intentional overflow */ - && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { - const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; - size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; - { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, STORE_REPCODE_1, repLength2); - hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; - ip += repLength2; - anchor = ip; - continue; - } - break; - } } } + { /* load match for ip[0] */ + U32 const mval = idx >= dictStartIndex ? + MEM_read32(idxBase + idx) : + MEM_read32(ip0) ^ 1; /* guaranteed not to match */ + + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ + goto _offset; + } } + + /* lookup ip[1] */ + idx = hashTable[hash1]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + /* hash ip[2] */ + hash0 = hash1; + hash1 = ZSTD_hashPtr(ip2, hlog, mls); + + /* advance to next positions */ + ip0 = ip1; + ip1 = ip2; + ip2 = ip3; + + /* write back hash table entry */ + current0 = (U32)(ip0 - base); + hashTable[hash0] = current0; + + { /* load match for ip[0] */ + U32 const mval = idx >= dictStartIndex ? + MEM_read32(idxBase + idx) : + MEM_read32(ip0) ^ 1; /* guaranteed not to match */ + + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ + goto _offset; + } } + + /* lookup ip[1] */ + idx = hashTable[hash1]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + /* hash ip[2] */ + hash0 = hash1; + hash1 = ZSTD_hashPtr(ip2, hlog, mls); + + /* advance to next positions */ + ip0 = ip1; + ip1 = ip2; + ip2 = ip0 + step; + ip3 = ip1 + step; + + /* calculate step */ + if (ip2 >= nextStep) { + step++; + PREFETCH_L1(ip1 + 64); + PREFETCH_L1(ip1 + 128); + nextStep += kStepIncr; + } + } while (ip3 < ilimit); + +_cleanup: + /* Note that there are probably still a couple positions we could search. + * However, it seems to be a meaningful performance hit to try to search + * them. So let's not. */ + + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; /* save reps for next block */ - rep[0] = offset_1; - rep[1] = offset_2; + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); + +_offset: /* Requires: ip0, idx, idxBase */ + + /* Compute the offset code. */ + { U32 const offset = current0 - idx; + const BYTE* const lowMatchPtr = idx < prefixStartIndex ? dictStart : prefixStart; + matchEnd = idx < prefixStartIndex ? dictEnd : iend; + match0 = idxBase + idx; + offset_2 = offset_1; + offset_1 = offset; + offcode = OFFSET_TO_OFFBASE(offset); + mLength = 4; + + /* Count the backwards match length. */ + while (((ip0>anchor) & (match0>lowMatchPtr)) && (ip0[-1] == match0[-1])) { + ip0--; + match0--; + mLength++; + } } + +_match: /* Requires: ip0, match0, offcode, matchEnd */ + + /* Count the forward length. */ + assert(matchEnd != 0); + mLength += ZSTD_count_2segments(ip0 + mLength, match0 + mLength, iend, matchEnd, prefixStart); + + ZSTD_storeSeq(seqStore, (size_t)(ip0 - anchor), anchor, iend, offcode, mLength); + + ip0 += mLength; + anchor = ip0; + + /* write next hash table entry */ + if (ip1 < ip0) { + hashTable[hash1] = (U32)(ip1 - base); + } + + /* Fill table and check for immediate repcode. */ + if (ip0 <= ilimit) { + /* Fill Table */ + assert(base+current0+2 > istart); /* check base overflow */ + hashTable[ZSTD_hashPtr(base+current0+2, hlog, mls)] = current0+2; /* here because current+2 could be > iend-8 */ + hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); + + while (ip0 <= ilimit) { + U32 const repIndex2 = (U32)(ip0-base) - offset_2; + const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; + if ( ((ZSTD_index_overlap_check(prefixStartIndex, repIndex2)) & (offset_2 > 0)) + && (MEM_read32(repMatch2) == MEM_read32(ip0)) ) { + const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; + size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); + ip0 += repLength2; + anchor = ip0; + continue; + } + break; + } } + + goto _start; } ZSTD_GEN_FAST_FN(extDict, 4, 0) @@ -656,10 +964,11 @@ ZSTD_GEN_FAST_FN(extDict, 6, 0) ZSTD_GEN_FAST_FN(extDict, 7, 0) size_t ZSTD_compressBlock_fast_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { U32 const mls = ms->cParams.minMatch; + assert(ms->dictMatchState == NULL); switch(mls) { default: /* includes case 3 */ diff --git a/lib/zstd/compress/zstd_fast.h b/lib/zstd/compress/zstd_fast.h index fddc2f532d21..04fde0a72a4e 100644 --- a/lib/zstd/compress/zstd_fast.h +++ b/lib/zstd/compress/zstd_fast.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,21 +12,20 @@ #ifndef ZSTD_FAST_H #define ZSTD_FAST_H - #include "../common/mem.h" /* U32 */ #include "zstd_compress_internal.h" -void ZSTD_fillHashTable(ZSTD_matchState_t* ms, - void const* end, ZSTD_dictTableLoadMethod_e dtlm); +void ZSTD_fillHashTable(ZSTD_MatchState_t* ms, + void const* end, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp); size_t ZSTD_compressBlock_fast( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_fast_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_fast_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - #endif /* ZSTD_FAST_H */ diff --git a/lib/zstd/compress/zstd_lazy.c b/lib/zstd/compress/zstd_lazy.c index 0298a01a7504..88e2501fe3ef 100644 --- a/lib/zstd/compress/zstd_lazy.c +++ b/lib/zstd/compress/zstd_lazy.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -10,14 +11,23 @@ #include "zstd_compress_internal.h" #include "zstd_lazy.h" +#include "../common/bits.h" /* ZSTD_countTrailingZeros64 */ + +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) + +#define kLazySkippingStep 8 /*-************************************* * Binary Tree search ***************************************/ -static void -ZSTD_updateDUBT(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_updateDUBT(ZSTD_MatchState_t* ms, const BYTE* ip, const BYTE* iend, U32 mls) { @@ -60,8 +70,9 @@ ZSTD_updateDUBT(ZSTD_matchState_t* ms, * sort one already inserted but unsorted position * assumption : curr >= btlow == (curr - btmask) * doesn't fail */ -static void -ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_insertDUBT1(const ZSTD_MatchState_t* ms, U32 curr, const BYTE* inputEnd, U32 nbCompares, U32 btLow, const ZSTD_dictMode_e dictMode) @@ -149,9 +160,10 @@ ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, } -static size_t -ZSTD_DUBT_findBetterDictMatch ( - const ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_DUBT_findBetterDictMatch ( + const ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iend, size_t* offsetPtr, size_t bestLength, @@ -159,7 +171,7 @@ ZSTD_DUBT_findBetterDictMatch ( U32 const mls, const ZSTD_dictMode_e dictMode) { - const ZSTD_matchState_t * const dms = ms->dictMatchState; + const ZSTD_MatchState_t * const dms = ms->dictMatchState; const ZSTD_compressionParameters* const dmsCParams = &dms->cParams; const U32 * const dictHashTable = dms->hashTable; U32 const hashLog = dmsCParams->hashLog; @@ -197,8 +209,8 @@ ZSTD_DUBT_findBetterDictMatch ( U32 matchIndex = dictMatchIndex + dictIndexDelta; if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) { DEBUGLOG(9, "ZSTD_DUBT_findBetterDictMatch(%u) : found better match length %u -> %u and offsetCode %u -> %u (dictMatchIndex %u, matchIndex %u)", - curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, STORE_OFFSET(curr - matchIndex), dictMatchIndex, matchIndex); - bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); + curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, OFFSET_TO_OFFBASE(curr - matchIndex), dictMatchIndex, matchIndex); + bestLength = matchLength, *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); } if (ip+matchLength == iend) { /* reached end of input : ip[matchLength] is not valid, no way to know if it's larger or smaller than match */ break; /* drop, to guarantee consistency (miss a little bit of compression) */ @@ -218,7 +230,7 @@ ZSTD_DUBT_findBetterDictMatch ( } if (bestLength >= MINMATCH) { - U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; + U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offsetPtr); (void)mIndex; DEBUGLOG(8, "ZSTD_DUBT_findBetterDictMatch(%u) : found match of length %u and offsetCode %u (pos %u)", curr, (U32)bestLength, (U32)*offsetPtr, mIndex); } @@ -227,10 +239,11 @@ ZSTD_DUBT_findBetterDictMatch ( } -static size_t -ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_DUBT_findBestMatch(ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iend, - size_t* offsetPtr, + size_t* offBasePtr, U32 const mls, const ZSTD_dictMode_e dictMode) { @@ -327,8 +340,8 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, if (matchLength > bestLength) { if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; - if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) - bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); + if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr - matchIndex + 1) - ZSTD_highbit32((U32)*offBasePtr)) ) + bestLength = matchLength, *offBasePtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+matchLength == iend) { /* equal : no way to know if inf or sup */ if (dictMode == ZSTD_dictMatchState) { nbCompares = 0; /* in addition to avoiding checking any @@ -361,16 +374,16 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, if (dictMode == ZSTD_dictMatchState && nbCompares) { bestLength = ZSTD_DUBT_findBetterDictMatch( ms, ip, iend, - offsetPtr, bestLength, nbCompares, + offBasePtr, bestLength, nbCompares, mls, dictMode); } assert(matchEndIdx > curr+8); /* ensure nextToUpdate is increased */ ms->nextToUpdate = matchEndIdx - 8; /* skip repetitive patterns */ if (bestLength >= MINMATCH) { - U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; + U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offBasePtr); (void)mIndex; DEBUGLOG(8, "ZSTD_DUBT_findBestMatch(%u) : found match of length %u and offsetCode %u (pos %u)", - curr, (U32)bestLength, (U32)*offsetPtr, mIndex); + curr, (U32)bestLength, (U32)*offBasePtr, mIndex); } return bestLength; } @@ -378,24 +391,25 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, /* ZSTD_BtFindBestMatch() : Tree updater, providing best match */ -FORCE_INLINE_TEMPLATE size_t -ZSTD_BtFindBestMatch( ZSTD_matchState_t* ms, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_BtFindBestMatch( ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, - size_t* offsetPtr, + size_t* offBasePtr, const U32 mls /* template */, const ZSTD_dictMode_e dictMode) { DEBUGLOG(7, "ZSTD_BtFindBestMatch"); if (ip < ms->window.base + ms->nextToUpdate) return 0; /* skipped area */ ZSTD_updateDUBT(ms, ip, iLimit, mls); - return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offsetPtr, mls, dictMode); + return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offBasePtr, mls, dictMode); } /* ********************************* * Dedicated dict search ***********************************/ -void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const BYTE* const ip) +void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_MatchState_t* ms, const BYTE* const ip) { const BYTE* const base = ms->window.base; U32 const target = (U32)(ip - base); @@ -514,7 +528,7 @@ void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const B */ FORCE_INLINE_TEMPLATE size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nbAttempts, - const ZSTD_matchState_t* const dms, + const ZSTD_MatchState_t* const dms, const BYTE* const ip, const BYTE* const iLimit, const BYTE* const prefixStart, const U32 curr, const U32 dictLimit, const size_t ddsIdx) { @@ -561,7 +575,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); if (ip+currentMl == iLimit) { /* best possible, avoids read overflow on next attempt */ return ml; @@ -598,7 +612,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } } @@ -614,10 +628,12 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* Update chains up to ip (excluded) Assumption : always within prefix (i.e. not within extDict) */ -FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( - ZSTD_matchState_t* ms, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertAndFindFirstIndex_internal( + ZSTD_MatchState_t* ms, const ZSTD_compressionParameters* const cParams, - const BYTE* ip, U32 const mls) + const BYTE* ip, U32 const mls, U32 const lazySkipping) { U32* const hashTable = ms->hashTable; const U32 hashLog = cParams->hashLog; @@ -632,21 +648,25 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( NEXT_IN_CHAIN(idx, chainMask) = hashTable[h]; hashTable[h] = idx; idx++; + /* Stop inserting every position when in the lazy skipping mode. */ + if (lazySkipping) + break; } ms->nextToUpdate = target; return hashTable[ZSTD_hashPtr(ip, hashLog, mls)]; } -U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip) { +U32 ZSTD_insertAndFindFirstIndex(ZSTD_MatchState_t* ms, const BYTE* ip) { const ZSTD_compressionParameters* const cParams = &ms->cParams; - return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch); + return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch, /* lazySkipping*/ 0); } /* inlining is important to hardwire a hot branch (template emulation) */ FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_HcFindBestMatch( - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, size_t* offsetPtr, const U32 mls, const ZSTD_dictMode_e dictMode) @@ -670,7 +690,7 @@ size_t ZSTD_HcFindBestMatch( U32 nbAttempts = 1U << cParams->searchLog; size_t ml=4-1; - const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_MatchState_t* const dms = ms->dictMatchState; const U32 ddsHashLog = dictMode == ZSTD_dedicatedDictSearch ? dms->cParams.hashLog - ZSTD_LAZY_DDSS_BUCKET_LOG : 0; const size_t ddsIdx = dictMode == ZSTD_dedicatedDictSearch @@ -684,14 +704,15 @@ size_t ZSTD_HcFindBestMatch( } /* HC4 match finder */ - matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls); + matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls, ms->lazySkipping); for ( ; (matchIndex>=lowLimit) & (nbAttempts>0) ; nbAttempts--) { size_t currentMl=0; if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { const BYTE* const match = base + matchIndex; assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ - if (match[ml] == ip[ml]) /* potentially better */ + /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ + if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ currentMl = ZSTD_count(ip, match, iLimit); } else { const BYTE* const match = dictBase + matchIndex; @@ -703,7 +724,7 @@ size_t ZSTD_HcFindBestMatch( /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - matchIndex); + *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } @@ -739,7 +760,7 @@ size_t ZSTD_HcFindBestMatch( if (currentMl > ml) { ml = currentMl; assert(curr > matchIndex + dmsIndexDelta); - *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } @@ -756,8 +777,6 @@ size_t ZSTD_HcFindBestMatch( * (SIMD) Row-based matchfinder ***********************************/ /* Constants for row-based hash */ -#define ZSTD_ROW_HASH_TAG_OFFSET 16 /* byte offset of hashes in the match state's tagTable from the beginning of a row */ -#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ #define ZSTD_ROW_HASH_TAG_MASK ((1u << ZSTD_ROW_HASH_TAG_BITS) - 1) #define ZSTD_ROW_HASH_MAX_ENTRIES 64 /* absolute maximum number of entries per row, for all configurations */ @@ -769,64 +788,19 @@ typedef U64 ZSTD_VecMask; /* Clarifies when we are interacting with a U64 repr * Starting from the LSB, returns the idx of the next non-zero bit. * Basically counting the nb of trailing zeroes. */ -static U32 ZSTD_VecMask_next(ZSTD_VecMask val) { - assert(val != 0); -# if (defined(__GNUC__) && ((__GNUC__ > 3) || ((__GNUC__ == 3) && (__GNUC_MINOR__ >= 4)))) - if (sizeof(size_t) == 4) { - U32 mostSignificantWord = (U32)(val >> 32); - U32 leastSignificantWord = (U32)val; - if (leastSignificantWord == 0) { - return 32 + (U32)__builtin_ctz(mostSignificantWord); - } else { - return (U32)__builtin_ctz(leastSignificantWord); - } - } else { - return (U32)__builtin_ctzll(val); - } -# else - /* Software ctz version: http://aggregate.org/MAGIC/#Trailing%20Zero%20Count - * and: https://stackoverflow.com/questions/2709430/count-number-of-bits-in-a-64-bit-long-big-integer - */ - val = ~val & (val - 1ULL); /* Lowest set bit mask */ - val = val - ((val >> 1) & 0x5555555555555555); - val = (val & 0x3333333333333333ULL) + ((val >> 2) & 0x3333333333333333ULL); - return (U32)((((val + (val >> 4)) & 0xF0F0F0F0F0F0F0FULL) * 0x101010101010101ULL) >> 56); -# endif -} - -/* ZSTD_rotateRight_*(): - * Rotates a bitfield to the right by "count" bits. - * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts - */ -FORCE_INLINE_TEMPLATE -U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { - assert(count < 64); - count &= 0x3F; /* for fickle pattern recognition */ - return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); -} - -FORCE_INLINE_TEMPLATE -U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { - assert(count < 32); - count &= 0x1F; /* for fickle pattern recognition */ - return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); -} - -FORCE_INLINE_TEMPLATE -U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { - assert(count < 16); - count &= 0x0F; /* for fickle pattern recognition */ - return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); +MEM_STATIC U32 ZSTD_VecMask_next(ZSTD_VecMask val) { + return ZSTD_countTrailingZeros64(val); } /* ZSTD_row_nextIndex(): * Returns the next index to insert at within a tagTable row, and updates the "head" - * value to reflect the update. Essentially cycles backwards from [0, {entries per row}) + * value to reflect the update. Essentially cycles backwards from [1, {entries per row}) */ FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextIndex(BYTE* const tagRow, U32 const rowMask) { - U32 const next = (*tagRow - 1) & rowMask; - *tagRow = (BYTE)next; - return next; + U32 next = (*tagRow-1) & rowMask; + next += (next == 0) ? rowMask : 0; /* skip first position */ + *tagRow = (BYTE)next; + return next; } /* ZSTD_isAligned(): @@ -840,7 +814,7 @@ MEM_STATIC int ZSTD_isAligned(void const* ptr, size_t align) { /* ZSTD_row_prefetch(): * Performs prefetching for the hashTable and tagTable at a given row. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* tagTable, U32 const relRow, U32 const rowLog) { +FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, BYTE const* tagTable, U32 const relRow, U32 const rowLog) { PREFETCH_L1(hashTable + relRow); if (rowLog >= 5) { PREFETCH_L1(hashTable + relRow + 16); @@ -859,18 +833,20 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* ta * Fill up the hash cache starting at idx, prefetching up to ZSTD_ROW_HASH_CACHE_SIZE entries, * but not beyond iLimit. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const BYTE* base, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_fillHashCache(ZSTD_MatchState_t* ms, const BYTE* base, U32 const rowLog, U32 const mls, U32 idx, const BYTE* const iLimit) { U32 const* const hashTable = ms->hashTable; - U16 const* const tagTable = ms->tagTable; + BYTE const* const tagTable = ms->tagTable; U32 const hashLog = ms->rowHashLog; U32 const maxElemsToPrefetch = (base + idx) > iLimit ? 0 : (U32)(iLimit - (base + idx) + 1); U32 const lim = idx + MIN(ZSTD_ROW_HASH_CACHE_SIZE, maxElemsToPrefetch); for (; idx < lim; ++idx) { - U32 const hash = (U32)ZSTD_hashPtr(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const hash = (U32)ZSTD_hashPtrSalted(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); U32 const row = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); ms->hashCache[idx & ZSTD_ROW_HASH_CACHE_MASK] = hash; @@ -885,12 +861,15 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const B * Returns the hash of base + idx, and replaces the hash in the hash cache with the byte at * base + idx + ZSTD_ROW_HASH_CACHE_SIZE. Also prefetches the appropriate rows from hashTable and tagTable. */ -FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, - U16 const* tagTable, BYTE const* base, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, + BYTE const* tagTable, BYTE const* base, U32 idx, U32 const hashLog, - U32 const rowLog, U32 const mls) + U32 const rowLog, U32 const mls, + U64 const hashSalt) { - U32 const newHash = (U32)ZSTD_hashPtr(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const newHash = (U32)ZSTD_hashPtrSalted(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); U32 const row = (newHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); { U32 const hash = cache[idx & ZSTD_ROW_HASH_CACHE_MASK]; @@ -902,28 +881,29 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTab /* ZSTD_row_update_internalImpl(): * Updates the hash table with positions starting from updateStartIdx until updateEndIdx. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, - U32 updateStartIdx, U32 const updateEndIdx, - U32 const mls, U32 const rowLog, - U32 const rowMask, U32 const useCache) +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_update_internalImpl(ZSTD_MatchState_t* ms, + U32 updateStartIdx, U32 const updateEndIdx, + U32 const mls, U32 const rowLog, + U32 const rowMask, U32 const useCache) { U32* const hashTable = ms->hashTable; - U16* const tagTable = ms->tagTable; + BYTE* const tagTable = ms->tagTable; U32 const hashLog = ms->rowHashLog; const BYTE* const base = ms->window.base; DEBUGLOG(6, "ZSTD_row_update_internalImpl(): updateStartIdx=%u, updateEndIdx=%u", updateStartIdx, updateEndIdx); for (; updateStartIdx < updateEndIdx; ++updateStartIdx) { - U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls) - : (U32)ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls, ms->hashSalt) + : (U32)ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; U32* const row = hashTable + relRow; - BYTE* tagRow = (BYTE*)(tagTable + relRow); /* Though tagTable is laid out as a table of U16, each tag is only 1 byte. - Explicit cast allows us to get exact desired position within each row */ + BYTE* tagRow = tagTable + relRow; U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); - assert(hash == ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls)); - ((BYTE*)tagRow)[pos + ZSTD_ROW_HASH_TAG_OFFSET] = hash & ZSTD_ROW_HASH_TAG_MASK; + assert(hash == ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt)); + tagRow[pos] = hash & ZSTD_ROW_HASH_TAG_MASK; row[pos] = updateStartIdx; } } @@ -932,9 +912,11 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, * Inserts the byte at ip into the appropriate position in the hash table, and updates ms->nextToUpdate. * Skips sections of long matches as is necessary. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const BYTE* ip, - U32 const mls, U32 const rowLog, - U32 const rowMask, U32 const useCache) +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_update_internal(ZSTD_MatchState_t* ms, const BYTE* ip, + U32 const mls, U32 const rowLog, + U32 const rowMask, U32 const useCache) { U32 idx = ms->nextToUpdate; const BYTE* const base = ms->window.base; @@ -965,13 +947,41 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const * External wrapper for ZSTD_row_update_internal(). Used for filling the hashtable during dictionary * processing. */ -void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) { +void ZSTD_row_update(ZSTD_MatchState_t* const ms, const BYTE* ip) { const U32 rowLog = BOUNDED(4, ms->cParams.searchLog, 6); const U32 rowMask = (1u << rowLog) - 1; const U32 mls = MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */); DEBUGLOG(5, "ZSTD_row_update(), rowLog=%u", rowLog); - ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* dont use cache */); + ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* don't use cache */); +} + +/* Returns the mask width of bits group of which will be set to 1. Given not all + * architectures have easy movemask instruction, this helps to iterate over + * groups of bits easier and faster. + */ +FORCE_INLINE_TEMPLATE U32 +ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) +{ + assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); + assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); + (void)rowEntries; +#if defined(ZSTD_ARCH_ARM_NEON) + /* NEON path only works for little endian */ + if (!MEM_isLittleEndian()) { + return 1; + } + if (rowEntries == 16) { + return 4; + } + if (rowEntries == 32) { + return 2; + } + if (rowEntries == 64) { + return 1; + } +#endif + return 1; } #if defined(ZSTD_ARCH_X86_SSE2) @@ -994,71 +1004,82 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U } #endif -/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches - * the hash at the nth position in a row of the tagTable. - * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield - * to match up with the actual layout of the entries within the hashTable */ +#if defined(ZSTD_ARCH_ARM_NEON) +FORCE_INLINE_TEMPLATE ZSTD_VecMask +ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 headGrouped) +{ + assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); + if (rowEntries == 16) { + /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits. + * After that groups of 4 bits represent the equalMask. We lower + * all bits except the highest in these groups by doing AND with + * 0x88 = 0b10001000. + */ + const uint8x16_t chunk = vld1q_u8(src); + const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); + const uint8x8_t res = vshrn_n_u16(equalMask, 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0); + return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull; + } else if (rowEntries == 32) { + /* Same idea as with rowEntries == 16 but doing AND with + * 0x55 = 0b01010101. + */ + const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src); + const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); + const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); + const uint8x16_t dup = vdupq_n_u8(tag); + const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6); + const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6); + const uint8x8_t res = vsli_n_u8(t0, t1, 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ; + return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull; + } else { /* rowEntries == 64 */ + const uint8x16x4_t chunk = vld4q_u8(src); + const uint8x16_t dup = vdupq_n_u8(tag); + const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); + const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); + const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); + const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); + + const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); + const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); + const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); + const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); + const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); + return ZSTD_rotateRight_U64(matches, headGrouped); + } +} +#endif + +/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by + * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag" + * matches the hash at the nth position in a row of the tagTable. + * Each row is a circular buffer beginning at the value of "headGrouped". So we + * must rotate the "matches" bitfield to match up with the actual layout of the + * entries within the hashTable */ FORCE_INLINE_TEMPLATE ZSTD_VecMask -ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) +ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries) { - const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET; + const BYTE* const src = tagRow; assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); + assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8); #if defined(ZSTD_ARCH_X86_SSE2) - return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head); + return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped); #else /* SW or NEON-LE */ # if defined(ZSTD_ARCH_ARM_NEON) /* This NEON path only works for little endian - otherwise use SWAR below */ if (MEM_isLittleEndian()) { - if (rowEntries == 16) { - const uint8x16_t chunk = vld1q_u8(src); - const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); - const uint16x8_t t0 = vshlq_n_u16(equalMask, 7); - const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14)); - const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14)); - const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28)); - const U16 hi = (U16)vgetq_lane_u8(t3, 8); - const U16 lo = (U16)vgetq_lane_u8(t3, 0); - return ZSTD_rotateRight_U16((hi << 8) | lo, head); - } else if (rowEntries == 32) { - const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src); - const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); - const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); - const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag)); - const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag)); - const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0)); - const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1)); - const uint8x8_t t0 = vreinterpret_u8_s8(pack0); - const uint8x8_t t1 = vreinterpret_u8_s8(pack1); - const uint8x8_t t2 = vsri_n_u8(t1, t0, 2); - const uint8x8x2_t t3 = vuzp_u8(t2, t0); - const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4); - const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0); - return ZSTD_rotateRight_U32(matches, head); - } else { /* rowEntries == 64 */ - const uint8x16x4_t chunk = vld4q_u8(src); - const uint8x16_t dup = vdupq_n_u8(tag); - const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); - const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); - const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); - const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); - - const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); - const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); - const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); - const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); - const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); - const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); - return ZSTD_rotateRight_U64(matches, head); - } + return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped); } # endif /* ZSTD_ARCH_ARM_NEON */ /* SWAR */ - { const size_t chunkSize = sizeof(size_t); + { const int chunkSize = sizeof(size_t); const size_t shiftAmount = ((chunkSize * 8) - chunkSize); const size_t xFF = ~((size_t)0); const size_t x01 = xFF / 0xFF; @@ -1091,11 +1112,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, } matches = ~matches; if (rowEntries == 16) { - return ZSTD_rotateRight_U16((U16)matches, head); + return ZSTD_rotateRight_U16((U16)matches, headGrouped); } else if (rowEntries == 32) { - return ZSTD_rotateRight_U32((U32)matches, head); + return ZSTD_rotateRight_U32((U32)matches, headGrouped); } else { - return ZSTD_rotateRight_U64((U64)matches, head); + return ZSTD_rotateRight_U64((U64)matches, headGrouped); } } #endif @@ -1103,29 +1124,30 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, /* The high-level approach of the SIMD row based match finder is as follows: * - Figure out where to insert the new entry: - * - Generate a hash from a byte along with an additional 1-byte "short hash". The additional byte is our "tag" - * - The hashTable is effectively split into groups or "rows" of 16 or 32 entries of U32, and the hash determines + * - Generate a hash for current input position and split it into a one byte of tag and `rowHashLog` bits of index. + * - The hash is salted by a value that changes on every context reset, so when the same table is used + * we will avoid collisions that would otherwise slow us down by introducing phantom matches. + * - The hashTable is effectively split into groups or "rows" of 15 or 31 entries of U32, and the index determines * which row to insert into. - * - Determine the correct position within the row to insert the entry into. Each row of 16 or 32 can - * be considered as a circular buffer with a "head" index that resides in the tagTable. - * - Also insert the "tag" into the equivalent row and position in the tagTable. - * - Note: The tagTable has 17 or 33 1-byte entries per row, due to 16 or 32 tags, and 1 "head" entry. - * The 17 or 33 entry rows are spaced out to occur every 32 or 64 bytes, respectively, - * for alignment/performance reasons, leaving some bytes unused. - * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte "short hash" and + * - Determine the correct position within the row to insert the entry into. Each row of 15 or 31 can + * be considered as a circular buffer with a "head" index that resides in the tagTable (overall 16 or 32 bytes + * per row). + * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte tag calculated for the position and * generate a bitfield that we can cycle through to check the collisions in the hash table. * - Pick the longest match. + * - Insert the tag into the equivalent row and position in the tagTable. */ FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_RowFindBestMatch( - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, size_t* offsetPtr, const U32 mls, const ZSTD_dictMode_e dictMode, const U32 rowLog) { U32* const hashTable = ms->hashTable; - U16* const tagTable = ms->tagTable; + BYTE* const tagTable = ms->tagTable; U32* const hashCache = ms->hashCache; const U32 hashLog = ms->rowHashLog; const ZSTD_compressionParameters* const cParams = &ms->cParams; @@ -1143,11 +1165,14 @@ size_t ZSTD_RowFindBestMatch( const U32 rowEntries = (1U << rowLog); const U32 rowMask = rowEntries - 1; const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */ + const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries); + const U64 hashSalt = ms->hashSalt; U32 nbAttempts = 1U << cappedSearchLog; size_t ml=4-1; + U32 hash; /* DMS/DDS variables that may be referenced laster */ - const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_MatchState_t* const dms = ms->dictMatchState; /* Initialize the following variables to satisfy static analyzer */ size_t ddsIdx = 0; @@ -1168,7 +1193,7 @@ size_t ZSTD_RowFindBestMatch( if (dictMode == ZSTD_dictMatchState) { /* Prefetch DMS rows */ U32* const dmsHashTable = dms->hashTable; - U16* const dmsTagTable = dms->tagTable; + BYTE* const dmsTagTable = dms->tagTable; U32 const dmsHash = (U32)ZSTD_hashPtr(ip, dms->rowHashLog + ZSTD_ROW_HASH_TAG_BITS, mls); U32 const dmsRelRow = (dmsHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; dmsTag = dmsHash & ZSTD_ROW_HASH_TAG_MASK; @@ -1178,23 +1203,34 @@ size_t ZSTD_RowFindBestMatch( } /* Update the hashTable and tagTable up to (but not including) ip */ - ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */); + if (!ms->lazySkipping) { + ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */); + hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls, hashSalt); + } else { + /* Stop inserting every position when in the lazy skipping mode. + * The hash cache is also not kept up to date in this mode. + */ + hash = (U32)ZSTD_hashPtrSalted(ip, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); + ms->nextToUpdate = curr; + } + ms->hashSaltEntropy += hash; /* collect salt entropy */ + { /* Get the hash for ip, compute the appropriate row */ - U32 const hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls); U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK; U32* const row = hashTable + relRow; BYTE* tagRow = (BYTE*)(tagTable + relRow); - U32 const head = *tagRow & rowMask; + U32 const headGrouped = (*tagRow & rowMask) * groupWidth; U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; size_t numMatches = 0; size_t currMatch = 0; - ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries); + ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries); /* Cycle through the matches and prefetch */ - for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { - U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; + for (; (matches > 0) && (nbAttempts > 0); matches &= (matches - 1)) { + U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; U32 const matchIndex = row[matchPos]; + if(matchPos == 0) continue; assert(numMatches < rowEntries); if (matchIndex < lowLimit) break; @@ -1204,13 +1240,14 @@ size_t ZSTD_RowFindBestMatch( PREFETCH_L1(dictBase + matchIndex); } matchBuffer[numMatches++] = matchIndex; + --nbAttempts; } /* Speed opt: insert current byte into hashtable too. This allows us to avoid one iteration of the loop in ZSTD_row_update_internal() at the next search. */ { U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); - tagRow[pos + ZSTD_ROW_HASH_TAG_OFFSET] = (BYTE)tag; + tagRow[pos] = (BYTE)tag; row[pos] = ms->nextToUpdate++; } @@ -1224,7 +1261,8 @@ size_t ZSTD_RowFindBestMatch( if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { const BYTE* const match = base + matchIndex; assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ - if (match[ml] == ip[ml]) /* potentially better */ + /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ + if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ currentMl = ZSTD_count(ip, match, iLimit); } else { const BYTE* const match = dictBase + matchIndex; @@ -1236,7 +1274,7 @@ size_t ZSTD_RowFindBestMatch( /* Save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - matchIndex); + *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } } @@ -1254,19 +1292,21 @@ size_t ZSTD_RowFindBestMatch( const U32 dmsSize = (U32)(dmsEnd - dmsBase); const U32 dmsIndexDelta = dictLimit - dmsSize; - { U32 const head = *dmsTagRow & rowMask; + { U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth; U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; size_t numMatches = 0; size_t currMatch = 0; - ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries); + ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries); - for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { - U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; + for (; (matches > 0) && (nbAttempts > 0); matches &= (matches - 1)) { + U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; U32 const matchIndex = dmsRow[matchPos]; + if(matchPos == 0) continue; if (matchIndex < dmsLowestIndex) break; PREFETCH_L1(dmsBase + matchIndex); matchBuffer[numMatches++] = matchIndex; + --nbAttempts; } /* Return the longest match */ @@ -1285,7 +1325,7 @@ size_t ZSTD_RowFindBestMatch( if (currentMl > ml) { ml = currentMl; assert(curr > matchIndex + dmsIndexDelta); - *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); if (ip+currentMl == iLimit) break; } } @@ -1301,7 +1341,7 @@ size_t ZSTD_RowFindBestMatch( * ZSTD_searchMax() dispatches to the correct implementation function. * * TODO: The start of the search function involves loading and calculating a - * bunch of constants from the ZSTD_matchState_t. These computations could be + * bunch of constants from the ZSTD_MatchState_t. These computations could be * done in an initialization function, and saved somewhere in the match state. * Then we could pass a pointer to the saved state instead of the match state, * and avoid duplicate computations. @@ -1325,7 +1365,7 @@ size_t ZSTD_RowFindBestMatch( #define GEN_ZSTD_BT_SEARCH_FN(dictMode, mls) \ ZSTD_SEARCH_FN_ATTRS size_t ZSTD_BT_SEARCH_FN(dictMode, mls)( \ - ZSTD_matchState_t* ms, \ + ZSTD_MatchState_t* ms, \ const BYTE* ip, const BYTE* const iLimit, \ size_t* offBasePtr) \ { \ @@ -1335,7 +1375,7 @@ size_t ZSTD_RowFindBestMatch( #define GEN_ZSTD_HC_SEARCH_FN(dictMode, mls) \ ZSTD_SEARCH_FN_ATTRS size_t ZSTD_HC_SEARCH_FN(dictMode, mls)( \ - ZSTD_matchState_t* ms, \ + ZSTD_MatchState_t* ms, \ const BYTE* ip, const BYTE* const iLimit, \ size_t* offsetPtr) \ { \ @@ -1345,7 +1385,7 @@ size_t ZSTD_RowFindBestMatch( #define GEN_ZSTD_ROW_SEARCH_FN(dictMode, mls, rowLog) \ ZSTD_SEARCH_FN_ATTRS size_t ZSTD_ROW_SEARCH_FN(dictMode, mls, rowLog)( \ - ZSTD_matchState_t* ms, \ + ZSTD_MatchState_t* ms, \ const BYTE* ip, const BYTE* const iLimit, \ size_t* offsetPtr) \ { \ @@ -1446,7 +1486,7 @@ typedef enum { search_hashChain=0, search_binaryTree=1, search_rowHash=2 } searc * If a match is found its offset is stored in @p offsetPtr. */ FORCE_INLINE_TEMPLATE size_t ZSTD_searchMax( - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, const BYTE* ip, const BYTE* iend, size_t* offsetPtr, @@ -1472,9 +1512,10 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_searchMax( * Common parser - lazy strategy *********************************/ -FORCE_INLINE_TEMPLATE size_t -ZSTD_compressBlock_lazy_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_lazy_generic( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const searchMethod_e searchMethod, const U32 depth, @@ -1491,12 +1532,13 @@ ZSTD_compressBlock_lazy_generic( const U32 mls = BOUNDED(4, ms->cParams.minMatch, 6); const U32 rowLog = BOUNDED(4, ms->cParams.searchLog, 6); - U32 offset_1 = rep[0], offset_2 = rep[1], savedOffset=0; + U32 offset_1 = rep[0], offset_2 = rep[1]; + U32 offsetSaved1 = 0, offsetSaved2 = 0; const int isDMS = dictMode == ZSTD_dictMatchState; const int isDDS = dictMode == ZSTD_dedicatedDictSearch; const int isDxS = isDMS || isDDS; - const ZSTD_matchState_t* const dms = ms->dictMatchState; + const ZSTD_MatchState_t* const dms = ms->dictMatchState; const U32 dictLowestIndex = isDxS ? dms->window.dictLimit : 0; const BYTE* const dictBase = isDxS ? dms->window.base : NULL; const BYTE* const dictLowest = isDxS ? dictBase + dictLowestIndex : NULL; @@ -1512,8 +1554,8 @@ ZSTD_compressBlock_lazy_generic( U32 const curr = (U32)(ip - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, ms->cParams.windowLog); U32 const maxRep = curr - windowLow; - if (offset_2 > maxRep) savedOffset = offset_2, offset_2 = 0; - if (offset_1 > maxRep) savedOffset = offset_1, offset_1 = 0; + if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 > maxRep) offsetSaved1 = offset_1, offset_1 = 0; } if (isDxS) { /* dictMatchState repCode checks don't currently handle repCode == 0 @@ -1522,10 +1564,11 @@ ZSTD_compressBlock_lazy_generic( assert(offset_2 <= dictAndPrefixLength); } + /* Reset the lazy skipping state */ + ms->lazySkipping = 0; + if (searchMethod == search_rowHash) { - ZSTD_row_fillHashCache(ms, base, rowLog, - MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), - ms->nextToUpdate, ilimit); + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); } /* Match Loop */ @@ -1537,7 +1580,7 @@ ZSTD_compressBlock_lazy_generic( #endif while (ip < ilimit) { size_t matchLength=0; - size_t offcode=STORE_REPCODE_1; + size_t offBase = REPCODE1_TO_OFFBASE; const BYTE* start=ip+1; DEBUGLOG(7, "search baseline (depth 0)"); @@ -1548,7 +1591,7 @@ ZSTD_compressBlock_lazy_generic( && repIndex < prefixLowestIndex) ? dictBase + (repIndex - dictIndexDelta) : base + repIndex; - if (((U32)((prefixLowestIndex-1) - repIndex) >= 3 /* intentional underflow */) + if ((ZSTD_index_overlap_check(prefixLowestIndex, repIndex)) && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; matchLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; @@ -1562,14 +1605,23 @@ ZSTD_compressBlock_lazy_generic( } /* first search (depth 0) */ - { size_t offsetFound = 999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, dictMode); + { size_t offbaseFound = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offbaseFound, mls, rowLog, searchMethod, dictMode); if (ml2 > matchLength) - matchLength = ml2, start = ip, offcode=offsetFound; + matchLength = ml2, start = ip, offBase = offbaseFound; } if (matchLength < 4) { - ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ + size_t const step = ((size_t)(ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */; + ip += step; + /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. + * In this mode we stop inserting every position into our tables, and only insert + * positions that we search, which is one in step positions. + * The exact cutoff is flexible, I've just chosen a number that is reasonably high, + * so we minimize the compression ratio loss in "normal" scenarios. This mode gets + * triggered once we've gone 2KB without finding any matches. + */ + ms->lazySkipping = step > kLazySkippingStep; continue; } @@ -1579,34 +1631,34 @@ ZSTD_compressBlock_lazy_generic( DEBUGLOG(7, "search depth 1"); ip ++; if ( (dictMode == ZSTD_noDict) - && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; int const gain2 = (int)(mlRep * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } if (isDxS) { const U32 repIndex = (U32)(ip - base) - offset_1; const BYTE* repMatch = repIndex < prefixLowestIndex ? dictBase + (repIndex - dictIndexDelta) : base + repIndex; - if (((U32)((prefixLowestIndex-1) - repIndex) >= 3 /* intentional underflow */) + if ((ZSTD_index_overlap_check(prefixLowestIndex, repIndex)) && (MEM_read32(repMatch) == MEM_read32(ip)) ) { const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; int const gain2 = (int)(mlRep * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } } - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); + { size_t ofbCandidate=999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; /* search a better one */ } } @@ -1615,34 +1667,34 @@ ZSTD_compressBlock_lazy_generic( DEBUGLOG(7, "search depth 2"); ip ++; if ( (dictMode == ZSTD_noDict) - && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; int const gain2 = (int)(mlRep * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } if (isDxS) { const U32 repIndex = (U32)(ip - base) - offset_1; const BYTE* repMatch = repIndex < prefixLowestIndex ? dictBase + (repIndex - dictIndexDelta) : base + repIndex; - if (((U32)((prefixLowestIndex-1) - repIndex) >= 3 /* intentional underflow */) + if ((ZSTD_index_overlap_check(prefixLowestIndex, repIndex)) && (MEM_read32(repMatch) == MEM_read32(ip)) ) { const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; int const gain2 = (int)(mlRep * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } } - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); + { size_t ofbCandidate=999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; } } } break; /* nothing found : store previous solution */ @@ -1653,26 +1705,33 @@ ZSTD_compressBlock_lazy_generic( * notably if `value` is unsigned, resulting in a large positive `-value`. */ /* catch up */ - if (STORED_IS_OFFSET(offcode)) { + if (OFFBASE_IS_OFFSET(offBase)) { if (dictMode == ZSTD_noDict) { - while ( ((start > anchor) & (start - STORED_OFFSET(offcode) > prefixLowest)) - && (start[-1] == (start-STORED_OFFSET(offcode))[-1]) ) /* only search for offset within prefix */ + while ( ((start > anchor) & (start - OFFBASE_TO_OFFSET(offBase) > prefixLowest)) + && (start[-1] == (start-OFFBASE_TO_OFFSET(offBase))[-1]) ) /* only search for offset within prefix */ { start--; matchLength++; } } if (isDxS) { - U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); + U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); const BYTE* match = (matchIndex < prefixLowestIndex) ? dictBase + matchIndex - dictIndexDelta : base + matchIndex; const BYTE* const mStart = (matchIndex < prefixLowestIndex) ? dictLowest : prefixLowest; while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ } - offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); + offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); } /* store sequence */ _storeSequence: { size_t const litLength = (size_t)(start - anchor); - ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); + ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); anchor = ip = start + matchLength; } + if (ms->lazySkipping) { + /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ + if (searchMethod == search_rowHash) { + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); + } + ms->lazySkipping = 0; + } /* check immediate repcode */ if (isDxS) { @@ -1682,12 +1741,12 @@ ZSTD_compressBlock_lazy_generic( const BYTE* repMatch = repIndex < prefixLowestIndex ? dictBase - dictIndexDelta + repIndex : base + repIndex; - if ( ((U32)((prefixLowestIndex-1) - (U32)repIndex) >= 3 /* intentional overflow */) + if ( (ZSTD_index_overlap_check(prefixLowestIndex, repIndex)) && (MEM_read32(repMatch) == MEM_read32(ip)) ) { const BYTE* const repEnd2 = repIndex < prefixLowestIndex ? dictEnd : iend; matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd2, prefixLowest) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; @@ -1701,168 +1760,183 @@ ZSTD_compressBlock_lazy_generic( && (MEM_read32(ip) == MEM_read32(ip - offset_2)) ) { /* store sequence */ matchLength = ZSTD_count(ip+4, ip+4-offset_2, iend) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap repcodes */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap repcodes */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; /* faster when present ... (?) */ } } } - /* Save reps for next block */ - rep[0] = offset_1 ? offset_1 : savedOffset; - rep[1] = offset_2 ? offset_2 : savedOffset; + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; + + /* save reps for next block */ + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ -size_t ZSTD_compressBlock_btlazy2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_greedy( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); } -size_t ZSTD_compressBlock_lazy2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_greedy( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); } -size_t ZSTD_compressBlock_btlazy2_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dictMatchState_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy2_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); } - -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dictMatchState_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); } -/* Row-based matchfinder */ -size_t ZSTD_compressBlock_lazy2_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy2_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_lazy_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dictMatchState_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); } - size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_btlazy2_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); } +#endif +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_lazy_extDict_generic( - ZSTD_matchState_t* ms, seqStore_t* seqStore, + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const searchMethod_e searchMethod, const U32 depth) @@ -1886,12 +1960,13 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( DEBUGLOG(5, "ZSTD_compressBlock_lazy_extDict_generic (searchFunc=%u)", (U32)searchMethod); + /* Reset the lazy skipping state */ + ms->lazySkipping = 0; + /* init */ ip += (ip == prefixStart); if (searchMethod == search_rowHash) { - ZSTD_row_fillHashCache(ms, base, rowLog, - MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), - ms->nextToUpdate, ilimit); + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); } /* Match Loop */ @@ -1903,7 +1978,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( #endif while (ip < ilimit) { size_t matchLength=0; - size_t offcode=STORE_REPCODE_1; + size_t offBase = REPCODE1_TO_OFFBASE; const BYTE* start=ip+1; U32 curr = (U32)(ip-base); @@ -1912,7 +1987,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( const U32 repIndex = (U32)(curr+1 - offset_1); const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; const BYTE* const repMatch = repBase + repIndex; - if ( ((U32)((dictLimit-1) - repIndex) >= 3) /* intentional overflow */ + if ( (ZSTD_index_overlap_check(dictLimit, repIndex)) & (offset_1 <= curr+1 - windowLow) ) /* note: we are searching at curr+1 */ if (MEM_read32(ip+1) == MEM_read32(repMatch)) { /* repcode detected we should take it */ @@ -1922,14 +1997,23 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( } } /* first search (depth 0) */ - { size_t offsetFound = 999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, ZSTD_extDict); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); if (ml2 > matchLength) - matchLength = ml2, start = ip, offcode=offsetFound; + matchLength = ml2, start = ip, offBase = ofbCandidate; } if (matchLength < 4) { - ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ + size_t const step = ((size_t)(ip-anchor) >> kSearchStrength); + ip += step + 1; /* jump faster over incompressible sections */ + /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. + * In this mode we stop inserting every position into our tables, and only insert + * positions that we search, which is one in step positions. + * The exact cutoff is flexible, I've just chosen a number that is reasonably high, + * so we minimize the compression ratio loss in "normal" scenarios. This mode gets + * triggered once we've gone 2KB without finding any matches. + */ + ms->lazySkipping = step > kLazySkippingStep; continue; } @@ -1939,30 +2023,30 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( ip ++; curr++; /* check repCode */ - if (offcode) { + if (offBase) { const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); const U32 repIndex = (U32)(curr - offset_1); const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; const BYTE* const repMatch = repBase + repIndex; - if ( ((U32)((dictLimit-1) - repIndex) >= 3) /* intentional overflow : do not test positions overlapping 2 memory segments */ + if ( (ZSTD_index_overlap_check(dictLimit, repIndex)) & (offset_1 <= curr - windowLow) ) /* equivalent to `curr > repIndex >= windowLow` */ if (MEM_read32(ip) == MEM_read32(repMatch)) { /* repcode detected */ const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; int const gain2 = (int)(repLength * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((repLength >= 4) && (gain2 > gain1)) - matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; + matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; } } /* search match, depth 1 */ - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; /* search a better one */ } } @@ -1971,50 +2055,57 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( ip ++; curr++; /* check repCode */ - if (offcode) { + if (offBase) { const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); const U32 repIndex = (U32)(curr - offset_1); const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; const BYTE* const repMatch = repBase + repIndex; - if ( ((U32)((dictLimit-1) - repIndex) >= 3) /* intentional overflow : do not test positions overlapping 2 memory segments */ + if ( (ZSTD_index_overlap_check(dictLimit, repIndex)) & (offset_1 <= curr - windowLow) ) /* equivalent to `curr > repIndex >= windowLow` */ if (MEM_read32(ip) == MEM_read32(repMatch)) { /* repcode detected */ const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; int const gain2 = (int)(repLength * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((repLength >= 4) && (gain2 > gain1)) - matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; + matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; } } /* search match, depth 2 */ - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; } } } break; /* nothing found : store previous solution */ } /* catch up */ - if (STORED_IS_OFFSET(offcode)) { - U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); + if (OFFBASE_IS_OFFSET(offBase)) { + U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); const BYTE* match = (matchIndex < dictLimit) ? dictBase + matchIndex : base + matchIndex; const BYTE* const mStart = (matchIndex < dictLimit) ? dictStart : prefixStart; while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ - offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); + offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); } /* store sequence */ _storeSequence: { size_t const litLength = (size_t)(start - anchor); - ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); + ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); anchor = ip = start + matchLength; } + if (ms->lazySkipping) { + /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ + if (searchMethod == search_rowHash) { + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); + } + ms->lazySkipping = 0; + } /* check immediate repcode */ while (ip <= ilimit) { @@ -2023,14 +2114,14 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( const U32 repIndex = repCurrent - offset_2; const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; const BYTE* const repMatch = repBase + repIndex; - if ( ((U32)((dictLimit-1) - repIndex) >= 3) /* intentional overflow : do not test positions overlapping 2 memory segments */ + if ( (ZSTD_index_overlap_check(dictLimit, repIndex)) & (offset_2 <= repCurrent - windowLow) ) /* equivalent to `curr > repIndex >= windowLow` */ if (MEM_read32(ip) == MEM_read32(repMatch)) { /* repcode detected we should take it */ const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset history */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset history */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; /* faster when present ... (?) */ @@ -2045,58 +2136,65 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ - +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_greedy_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0); } -size_t ZSTD_compressBlock_lazy_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) - { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); } +#endif -size_t ZSTD_compressBlock_lazy2_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); } -size_t ZSTD_compressBlock_btlazy2_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); } +#endif -size_t ZSTD_compressBlock_greedy_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) + { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); } -size_t ZSTD_compressBlock_lazy_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) - { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); } +#endif -size_t ZSTD_compressBlock_lazy2_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); } +#endif diff --git a/lib/zstd/compress/zstd_lazy.h b/lib/zstd/compress/zstd_lazy.h index e5bdf4df8dde..987a036d8bde 100644 --- a/lib/zstd/compress/zstd_lazy.h +++ b/lib/zstd/compress/zstd_lazy.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,7 +12,6 @@ #ifndef ZSTD_LAZY_H #define ZSTD_LAZY_H - #include "zstd_compress_internal.h" /* @@ -22,98 +22,173 @@ */ #define ZSTD_LAZY_DDSS_BUCKET_LOG 2 -U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip); -void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip); +#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ + +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) +U32 ZSTD_insertAndFindFirstIndex(ZSTD_MatchState_t* ms, const BYTE* ip); +void ZSTD_row_update(ZSTD_MatchState_t* const ms, const BYTE* ip); -void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const BYTE* const ip); +void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_MatchState_t* ms, const BYTE* const ip); void ZSTD_preserveUnsortedMark (U32* const table, U32 const size, U32 const reducerValue); /*! used in ZSTD_reduceIndex(). preemptively increase value of ZSTD_DUBT_UNSORTED_MARK */ +#endif -size_t ZSTD_compressBlock_btlazy2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_greedy( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dictMatchState_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_greedy_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +size_t ZSTD_compressBlock_greedy_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_btlazy2_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#define ZSTD_COMPRESSBLOCK_GREEDY ZSTD_compressBlock_greedy +#define ZSTD_COMPRESSBLOCK_GREEDY_ROW ZSTD_compressBlock_greedy_row +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE ZSTD_compressBlock_greedy_dictMatchState +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW ZSTD_compressBlock_greedy_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH ZSTD_compressBlock_greedy_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_greedy_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT ZSTD_compressBlock_greedy_extDict +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW ZSTD_compressBlock_greedy_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_GREEDY NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_lazy_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_lazy_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dictMatchState_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + +#define ZSTD_COMPRESSBLOCK_LAZY ZSTD_compressBlock_lazy +#define ZSTD_COMPRESSBLOCK_LAZY_ROW ZSTD_compressBlock_lazy_row +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE ZSTD_compressBlock_lazy_dictMatchState +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT ZSTD_compressBlock_lazy_extDict +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW ZSTD_compressBlock_lazy_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_LAZY NULL +#define ZSTD_COMPRESSBLOCK_LAZY_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_greedy_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dictMatchState_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_lazy2_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_lazy2_extDict_row( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + +#define ZSTD_COMPRESSBLOCK_LAZY2 ZSTD_compressBlock_lazy2 +#define ZSTD_COMPRESSBLOCK_LAZY2_ROW ZSTD_compressBlock_lazy2_row +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE ZSTD_compressBlock_lazy2_dictMatchState +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy2_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy2_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy2_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT ZSTD_compressBlock_lazy2_extDict +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW ZSTD_compressBlock_lazy2_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_LAZY2 NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_extDict_row( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_btlazy2_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_btlazy2_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - +#define ZSTD_COMPRESSBLOCK_BTLAZY2 ZSTD_compressBlock_btlazy2 +#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE ZSTD_compressBlock_btlazy2_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT ZSTD_compressBlock_btlazy2_extDict +#else +#define ZSTD_COMPRESSBLOCK_BTLAZY2 NULL +#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT NULL +#endif #endif /* ZSTD_LAZY_H */ diff --git a/lib/zstd/compress/zstd_ldm.c b/lib/zstd/compress/zstd_ldm.c index dd86fc83e7dd..54eefad9cae6 100644 --- a/lib/zstd/compress/zstd_ldm.c +++ b/lib/zstd/compress/zstd_ldm.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -16,7 +17,7 @@ #include "zstd_double_fast.h" /* ZSTD_fillDoubleHashTable() */ #include "zstd_ldm_geartab.h" -#define LDM_BUCKET_SIZE_LOG 3 +#define LDM_BUCKET_SIZE_LOG 4 #define LDM_MIN_MATCH_LENGTH 64 #define LDM_HASH_RLOG 7 @@ -133,21 +134,35 @@ static size_t ZSTD_ldm_gear_feed(ldmRollingHashState_t* state, } void ZSTD_ldm_adjustParameters(ldmParams_t* params, - ZSTD_compressionParameters const* cParams) + const ZSTD_compressionParameters* cParams) { params->windowLog = cParams->windowLog; ZSTD_STATIC_ASSERT(LDM_BUCKET_SIZE_LOG <= ZSTD_LDM_BUCKETSIZELOG_MAX); DEBUGLOG(4, "ZSTD_ldm_adjustParameters"); - if (!params->bucketSizeLog) params->bucketSizeLog = LDM_BUCKET_SIZE_LOG; - if (!params->minMatchLength) params->minMatchLength = LDM_MIN_MATCH_LENGTH; + if (params->hashRateLog == 0) { + if (params->hashLog > 0) { + /* if params->hashLog is set, derive hashRateLog from it */ + assert(params->hashLog <= ZSTD_HASHLOG_MAX); + if (params->windowLog > params->hashLog) { + params->hashRateLog = params->windowLog - params->hashLog; + } + } else { + assert(1 <= (int)cParams->strategy && (int)cParams->strategy <= 9); + /* mapping from [fast, rate7] to [btultra2, rate4] */ + params->hashRateLog = 7 - (cParams->strategy/3); + } + } if (params->hashLog == 0) { - params->hashLog = MAX(ZSTD_HASHLOG_MIN, params->windowLog - LDM_HASH_RLOG); - assert(params->hashLog <= ZSTD_HASHLOG_MAX); + params->hashLog = BOUNDED(ZSTD_HASHLOG_MIN, params->windowLog - params->hashRateLog, ZSTD_HASHLOG_MAX); } - if (params->hashRateLog == 0) { - params->hashRateLog = params->windowLog < params->hashLog - ? 0 - : params->windowLog - params->hashLog; + if (params->minMatchLength == 0) { + params->minMatchLength = LDM_MIN_MATCH_LENGTH; + if (cParams->strategy >= ZSTD_btultra) + params->minMatchLength /= 2; + } + if (params->bucketSizeLog==0) { + assert(1 <= (int)cParams->strategy && (int)cParams->strategy <= 9); + params->bucketSizeLog = BOUNDED(LDM_BUCKET_SIZE_LOG, (U32)cParams->strategy, ZSTD_LDM_BUCKETSIZELOG_MAX); } params->bucketSizeLog = MIN(params->bucketSizeLog, params->hashLog); } @@ -170,22 +185,22 @@ size_t ZSTD_ldm_getMaxNbSeq(ldmParams_t params, size_t maxChunkSize) /* ZSTD_ldm_getBucket() : * Returns a pointer to the start of the bucket associated with hash. */ static ldmEntry_t* ZSTD_ldm_getBucket( - ldmState_t* ldmState, size_t hash, ldmParams_t const ldmParams) + const ldmState_t* ldmState, size_t hash, U32 const bucketSizeLog) { - return ldmState->hashTable + (hash << ldmParams.bucketSizeLog); + return ldmState->hashTable + (hash << bucketSizeLog); } /* ZSTD_ldm_insertEntry() : * Insert the entry with corresponding hash into the hash table */ static void ZSTD_ldm_insertEntry(ldmState_t* ldmState, size_t const hash, const ldmEntry_t entry, - ldmParams_t const ldmParams) + U32 const bucketSizeLog) { BYTE* const pOffset = ldmState->bucketOffsets + hash; unsigned const offset = *pOffset; - *(ZSTD_ldm_getBucket(ldmState, hash, ldmParams) + offset) = entry; - *pOffset = (BYTE)((offset + 1) & ((1u << ldmParams.bucketSizeLog) - 1)); + *(ZSTD_ldm_getBucket(ldmState, hash, bucketSizeLog) + offset) = entry; + *pOffset = (BYTE)((offset + 1) & ((1u << bucketSizeLog) - 1)); } @@ -234,7 +249,7 @@ static size_t ZSTD_ldm_countBackwardsMatch_2segments( * * The tables for the other strategies are filled within their * block compressors. */ -static size_t ZSTD_ldm_fillFastTables(ZSTD_matchState_t* ms, +static size_t ZSTD_ldm_fillFastTables(ZSTD_MatchState_t* ms, void const* end) { const BYTE* const iend = (const BYTE*)end; @@ -242,11 +257,15 @@ static size_t ZSTD_ldm_fillFastTables(ZSTD_matchState_t* ms, switch(ms->cParams.strategy) { case ZSTD_fast: - ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast); + ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); break; case ZSTD_dfast: - ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast); +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_greedy: @@ -269,7 +288,8 @@ void ZSTD_ldm_fillHashTable( const BYTE* iend, ldmParams_t const* params) { U32 const minMatchLength = params->minMatchLength; - U32 const hBits = params->hashLog - params->bucketSizeLog; + U32 const bucketSizeLog = params->bucketSizeLog; + U32 const hBits = params->hashLog - bucketSizeLog; BYTE const* const base = ldmState->window.base; BYTE const* const istart = ip; ldmRollingHashState_t hashState; @@ -284,7 +304,7 @@ void ZSTD_ldm_fillHashTable( unsigned n; numSplits = 0; - hashed = ZSTD_ldm_gear_feed(&hashState, ip, iend - ip, splits, &numSplits); + hashed = ZSTD_ldm_gear_feed(&hashState, ip, (size_t)(iend - ip), splits, &numSplits); for (n = 0; n < numSplits; n++) { if (ip + splits[n] >= istart + minMatchLength) { @@ -295,7 +315,7 @@ void ZSTD_ldm_fillHashTable( entry.offset = (U32)(split - base); entry.checksum = (U32)(xxhash >> 32); - ZSTD_ldm_insertEntry(ldmState, hash, entry, *params); + ZSTD_ldm_insertEntry(ldmState, hash, entry, params->bucketSizeLog); } } @@ -309,7 +329,7 @@ void ZSTD_ldm_fillHashTable( * Sets cctx->nextToUpdate to a position corresponding closer to anchor * if it is far way * (after a long match, only update tables a limited amount). */ -static void ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t* ms, const BYTE* anchor) +static void ZSTD_ldm_limitTableUpdate(ZSTD_MatchState_t* ms, const BYTE* anchor) { U32 const curr = (U32)(anchor - ms->window.base); if (curr > ms->nextToUpdate + 1024) { @@ -318,8 +338,10 @@ static void ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t* ms, const BYTE* anchor) } } -static size_t ZSTD_ldm_generateSequences_internal( - ldmState_t* ldmState, rawSeqStore_t* rawSeqStore, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_ldm_generateSequences_internal( + ldmState_t* ldmState, RawSeqStore_t* rawSeqStore, ldmParams_t const* params, void const* src, size_t srcSize) { /* LDM parameters */ @@ -373,7 +395,7 @@ static size_t ZSTD_ldm_generateSequences_internal( candidates[n].split = split; candidates[n].hash = hash; candidates[n].checksum = (U32)(xxhash >> 32); - candidates[n].bucket = ZSTD_ldm_getBucket(ldmState, hash, *params); + candidates[n].bucket = ZSTD_ldm_getBucket(ldmState, hash, params->bucketSizeLog); PREFETCH_L1(candidates[n].bucket); } @@ -396,7 +418,7 @@ static size_t ZSTD_ldm_generateSequences_internal( * the previous one, we merely register it in the hash table and * move on */ if (split < anchor) { - ZSTD_ldm_insertEntry(ldmState, hash, newEntry, *params); + ZSTD_ldm_insertEntry(ldmState, hash, newEntry, params->bucketSizeLog); continue; } @@ -443,7 +465,7 @@ static size_t ZSTD_ldm_generateSequences_internal( /* No match found -- insert an entry into the hash table * and process the next candidate match */ if (bestEntry == NULL) { - ZSTD_ldm_insertEntry(ldmState, hash, newEntry, *params); + ZSTD_ldm_insertEntry(ldmState, hash, newEntry, params->bucketSizeLog); continue; } @@ -464,7 +486,7 @@ static size_t ZSTD_ldm_generateSequences_internal( /* Insert the current entry into the hash table --- it must be * done after the previous block to avoid clobbering bestEntry */ - ZSTD_ldm_insertEntry(ldmState, hash, newEntry, *params); + ZSTD_ldm_insertEntry(ldmState, hash, newEntry, params->bucketSizeLog); anchor = split + forwardMatchLength; @@ -503,7 +525,7 @@ static void ZSTD_ldm_reduceTable(ldmEntry_t* const table, U32 const size, } size_t ZSTD_ldm_generateSequences( - ldmState_t* ldmState, rawSeqStore_t* sequences, + ldmState_t* ldmState, RawSeqStore_t* sequences, ldmParams_t const* params, void const* src, size_t srcSize) { U32 const maxDist = 1U << params->windowLog; @@ -549,7 +571,7 @@ size_t ZSTD_ldm_generateSequences( * the window through early invalidation. * TODO: * Test the chunk size. * * Try invalidation after the sequence generation and test the - * the offset against maxDist directly. + * offset against maxDist directly. * * NOTE: Because of dictionaries + sequence splitting we MUST make sure * that any offset used is valid at the END of the sequence, since it may @@ -580,7 +602,7 @@ size_t ZSTD_ldm_generateSequences( } void -ZSTD_ldm_skipSequences(rawSeqStore_t* rawSeqStore, size_t srcSize, U32 const minMatch) +ZSTD_ldm_skipSequences(RawSeqStore_t* rawSeqStore, size_t srcSize, U32 const minMatch) { while (srcSize > 0 && rawSeqStore->pos < rawSeqStore->size) { rawSeq* seq = rawSeqStore->seq + rawSeqStore->pos; @@ -616,7 +638,7 @@ ZSTD_ldm_skipSequences(rawSeqStore_t* rawSeqStore, size_t srcSize, U32 const min * Returns the current sequence to handle, or if the rest of the block should * be literals, it returns a sequence with offset == 0. */ -static rawSeq maybeSplitSequence(rawSeqStore_t* rawSeqStore, +static rawSeq maybeSplitSequence(RawSeqStore_t* rawSeqStore, U32 const remaining, U32 const minMatch) { rawSeq sequence = rawSeqStore->seq[rawSeqStore->pos]; @@ -640,7 +662,7 @@ static rawSeq maybeSplitSequence(rawSeqStore_t* rawSeqStore, return sequence; } -void ZSTD_ldm_skipRawSeqStoreBytes(rawSeqStore_t* rawSeqStore, size_t nbBytes) { +void ZSTD_ldm_skipRawSeqStoreBytes(RawSeqStore_t* rawSeqStore, size_t nbBytes) { U32 currPos = (U32)(rawSeqStore->posInSequence + nbBytes); while (currPos && rawSeqStore->pos < rawSeqStore->size) { rawSeq currSeq = rawSeqStore->seq[rawSeqStore->pos]; @@ -657,14 +679,14 @@ void ZSTD_ldm_skipRawSeqStoreBytes(rawSeqStore_t* rawSeqStore, size_t nbBytes) { } } -size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - ZSTD_paramSwitch_e useRowMatchFinder, +size_t ZSTD_ldm_blockCompress(RawSeqStore_t* rawSeqStore, + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_ParamSwitch_e useRowMatchFinder, void const* src, size_t srcSize) { const ZSTD_compressionParameters* const cParams = &ms->cParams; unsigned const minMatch = cParams->minMatch; - ZSTD_blockCompressor const blockCompressor = + ZSTD_BlockCompressor_f const blockCompressor = ZSTD_selectBlockCompressor(cParams->strategy, useRowMatchFinder, ZSTD_matchState_dictMode(ms)); /* Input bounds */ BYTE const* const istart = (BYTE const*)src; @@ -689,7 +711,6 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, /* maybeSplitSequence updates rawSeqStore->pos */ rawSeq const sequence = maybeSplitSequence(rawSeqStore, (U32)(iend - ip), minMatch); - int i; /* End signal */ if (sequence.offset == 0) break; @@ -702,6 +723,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, /* Run the block compressor */ DEBUGLOG(5, "pos %u : calling block compressor on segment of size %u", (unsigned)(ip-istart), sequence.litLength); { + int i; size_t const newLitLength = blockCompressor(ms, seqStore, rep, ip, sequence.litLength); ip += sequence.litLength; @@ -711,7 +733,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, rep[0] = sequence.offset; /* Store the sequence */ ZSTD_storeSeq(seqStore, newLitLength, ip - newLitLength, iend, - STORE_OFFSET(sequence.offset), + OFFSET_TO_OFFBASE(sequence.offset), sequence.matchLength); ip += sequence.matchLength; } diff --git a/lib/zstd/compress/zstd_ldm.h b/lib/zstd/compress/zstd_ldm.h index fbc6a5e88fd7..41400a7191b2 100644 --- a/lib/zstd/compress/zstd_ldm.h +++ b/lib/zstd/compress/zstd_ldm.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,7 +12,6 @@ #ifndef ZSTD_LDM_H #define ZSTD_LDM_H - #include "zstd_compress_internal.h" /* ldmParams_t, U32 */ #include /* ZSTD_CCtx, size_t */ @@ -40,7 +40,7 @@ void ZSTD_ldm_fillHashTable( * sequences. */ size_t ZSTD_ldm_generateSequences( - ldmState_t* ldms, rawSeqStore_t* sequences, + ldmState_t* ldms, RawSeqStore_t* sequences, ldmParams_t const* params, void const* src, size_t srcSize); /* @@ -61,9 +61,9 @@ size_t ZSTD_ldm_generateSequences( * two. We handle that case correctly, and update `rawSeqStore` appropriately. * NOTE: This function does not return any errors. */ -size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - ZSTD_paramSwitch_e useRowMatchFinder, +size_t ZSTD_ldm_blockCompress(RawSeqStore_t* rawSeqStore, + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_ParamSwitch_e useRowMatchFinder, void const* src, size_t srcSize); /* @@ -73,7 +73,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, * Avoids emitting matches less than `minMatch` bytes. * Must be called for data that is not passed to ZSTD_ldm_blockCompress(). */ -void ZSTD_ldm_skipSequences(rawSeqStore_t* rawSeqStore, size_t srcSize, +void ZSTD_ldm_skipSequences(RawSeqStore_t* rawSeqStore, size_t srcSize, U32 const minMatch); /* ZSTD_ldm_skipRawSeqStoreBytes(): @@ -81,7 +81,7 @@ void ZSTD_ldm_skipSequences(rawSeqStore_t* rawSeqStore, size_t srcSize, * Not to be used in conjunction with ZSTD_ldm_skipSequences(). * Must be called for data with is not passed to ZSTD_ldm_blockCompress(). */ -void ZSTD_ldm_skipRawSeqStoreBytes(rawSeqStore_t* rawSeqStore, size_t nbBytes); +void ZSTD_ldm_skipRawSeqStoreBytes(RawSeqStore_t* rawSeqStore, size_t nbBytes); /* ZSTD_ldm_getTableSize() : * Estimate the space needed for long distance matching tables or 0 if LDM is @@ -107,5 +107,4 @@ size_t ZSTD_ldm_getMaxNbSeq(ldmParams_t params, size_t maxChunkSize); void ZSTD_ldm_adjustParameters(ldmParams_t* params, ZSTD_compressionParameters const* cParams); - #endif /* ZSTD_FAST_H */ diff --git a/lib/zstd/compress/zstd_ldm_geartab.h b/lib/zstd/compress/zstd_ldm_geartab.h index 647f865be290..cfccfc46f6f7 100644 --- a/lib/zstd/compress/zstd_ldm_geartab.h +++ b/lib/zstd/compress/zstd_ldm_geartab.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_opt.c b/lib/zstd/compress/zstd_opt.c index fd82acfda62f..b62fd1b0d83e 100644 --- a/lib/zstd/compress/zstd_opt.c +++ b/lib/zstd/compress/zstd_opt.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Przemyslaw Skibinski, Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -12,11 +13,14 @@ #include "hist.h" #include "zstd_opt.h" +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) #define ZSTD_LITFREQ_ADD 2 /* scaling factor for litFreq, so that frequencies adapt faster to new stats */ #define ZSTD_MAX_PRICE (1<<30) -#define ZSTD_PREDEF_THRESHOLD 1024 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ +#define ZSTD_PREDEF_THRESHOLD 8 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ /*-************************************* @@ -26,27 +30,35 @@ #if 0 /* approximation at bit level (for tests) */ # define BITCOST_ACCURACY 0 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat, opt) ((void)opt, ZSTD_bitWeight(stat)) +# define WEIGHT(stat, opt) ((void)(opt), ZSTD_bitWeight(stat)) #elif 0 /* fractional bit accuracy (for tests) */ # define BITCOST_ACCURACY 8 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat,opt) ((void)opt, ZSTD_fracWeight(stat)) +# define WEIGHT(stat,opt) ((void)(opt), ZSTD_fracWeight(stat)) #else /* opt==approx, ultra==accurate */ # define BITCOST_ACCURACY 8 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat,opt) (opt ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) +# define WEIGHT(stat,opt) ((opt) ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) #endif +/* ZSTD_bitWeight() : + * provide estimated "cost" of a stat in full bits only */ MEM_STATIC U32 ZSTD_bitWeight(U32 stat) { return (ZSTD_highbit32(stat+1) * BITCOST_MULTIPLIER); } +/* ZSTD_fracWeight() : + * provide fractional-bit "cost" of a stat, + * using linear interpolation approximation */ MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) { U32 const stat = rawStat + 1; U32 const hb = ZSTD_highbit32(stat); U32 const BWeight = hb * BITCOST_MULTIPLIER; + /* Fweight was meant for "Fractional weight" + * but it's effectively a value between 1 and 2 + * using fixed point arithmetic */ U32 const FWeight = (stat << BITCOST_ACCURACY) >> hb; U32 const weight = BWeight + FWeight; assert(hb + BITCOST_ACCURACY < 31); @@ -57,7 +69,7 @@ MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) /* debugging function, * @return price in bytes as fractional value * for debug messages only */ -MEM_STATIC double ZSTD_fCost(U32 price) +MEM_STATIC double ZSTD_fCost(int price) { return (double)price / (BITCOST_MULTIPLIER*8); } @@ -88,20 +100,26 @@ static U32 sum_u32(const unsigned table[], size_t nbElts) return total; } -static U32 ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift) +typedef enum { base_0possible=0, base_1guaranteed=1 } base_directive_e; + +static U32 +ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift, base_directive_e base1) { U32 s, sum=0; - DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", (unsigned)lastEltIndex+1, (unsigned)shift); + DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", + (unsigned)lastEltIndex+1, (unsigned)shift ); assert(shift < 30); for (s=0; s> shift); - sum += table[s]; + unsigned const base = base1 ? 1 : (table[s]>0); + unsigned const newStat = base + (table[s] >> shift); + sum += newStat; + table[s] = newStat; } return sum; } /* ZSTD_scaleStats() : - * reduce all elements in table is sum too large + * reduce all elt frequencies in table if sum too large * return the resulting sum of elements */ static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) { @@ -110,7 +128,7 @@ static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) DEBUGLOG(5, "ZSTD_scaleStats (nbElts=%u, target=%u)", (unsigned)lastEltIndex+1, (unsigned)logTarget); assert(logTarget < 30); if (factor <= 1) return prevsum; - return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor)); + return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor), base_1guaranteed); } /* ZSTD_rescaleFreqs() : @@ -129,18 +147,22 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, DEBUGLOG(5, "ZSTD_rescaleFreqs (srcSize=%u)", (unsigned)srcSize); optPtr->priceType = zop_dynamic; - if (optPtr->litLengthSum == 0) { /* first block : init */ - if (srcSize <= ZSTD_PREDEF_THRESHOLD) { /* heuristic */ - DEBUGLOG(5, "(srcSize <= ZSTD_PREDEF_THRESHOLD) => zop_predef"); + if (optPtr->litLengthSum == 0) { /* no literals stats collected -> first block assumed -> init */ + + /* heuristic: use pre-defined stats for too small inputs */ + if (srcSize <= ZSTD_PREDEF_THRESHOLD) { + DEBUGLOG(5, "srcSize <= %i : use predefined stats", ZSTD_PREDEF_THRESHOLD); optPtr->priceType = zop_predef; } assert(optPtr->symbolCosts != NULL); if (optPtr->symbolCosts->huf.repeatMode == HUF_repeat_valid) { - /* huffman table presumed generated by dictionary */ + + /* huffman stats covering the full value set : table presumed generated by dictionary */ optPtr->priceType = zop_dynamic; if (compressedLiterals) { + /* generate literals statistics from huffman table */ unsigned lit; assert(optPtr->litFreq != NULL); optPtr->litSum = 0; @@ -188,13 +210,14 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, optPtr->offCodeSum += optPtr->offCodeFreq[of]; } } - } else { /* not a dictionary */ + } else { /* first block, no dictionary */ assert(optPtr->litFreq != NULL); if (compressedLiterals) { + /* base initial cost of literals on direct frequency within src */ unsigned lit = MaxLit; HIST_count_simple(optPtr->litFreq, &lit, src, srcSize); /* use raw first block to init statistics */ - optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8); + optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8, base_0possible); } { unsigned const baseLLfreqs[MaxLL+1] = { @@ -224,10 +247,9 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, optPtr->offCodeSum = sum_u32(baseOFCfreqs, MaxOff+1); } - } - } else { /* new block : re-use previous statistics, scaled down */ + } else { /* new block : scale down accumulated statistics */ if (compressedLiterals) optPtr->litSum = ZSTD_scaleStats(optPtr->litFreq, MaxLit, 12); @@ -246,6 +268,7 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, const optState_t* const optPtr, int optLevel) { + DEBUGLOG(8, "ZSTD_rawLiteralsCost (%u literals)", litLength); if (litLength == 0) return 0; if (!ZSTD_compressedLiterals(optPtr)) @@ -255,11 +278,14 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, return (litLength*6) * BITCOST_MULTIPLIER; /* 6 bit per literal - no statistic used */ /* dynamic statistics */ - { U32 price = litLength * optPtr->litSumBasePrice; + { U32 price = optPtr->litSumBasePrice * litLength; + U32 const litPriceMax = optPtr->litSumBasePrice - BITCOST_MULTIPLIER; U32 u; + assert(optPtr->litSumBasePrice >= BITCOST_MULTIPLIER); for (u=0; u < litLength; u++) { - assert(WEIGHT(optPtr->litFreq[literals[u]], optLevel) <= optPtr->litSumBasePrice); /* literal cost should never be negative */ - price -= WEIGHT(optPtr->litFreq[literals[u]], optLevel); + U32 litPrice = WEIGHT(optPtr->litFreq[literals[u]], optLevel); + if (UNLIKELY(litPrice > litPriceMax)) litPrice = litPriceMax; + price -= litPrice; } return price; } @@ -272,10 +298,11 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP assert(litLength <= ZSTD_BLOCKSIZE_MAX); if (optPtr->priceType == zop_predef) return WEIGHT(litLength, optLevel); - /* We can't compute the litLength price for sizes >= ZSTD_BLOCKSIZE_MAX - * because it isn't representable in the zstd format. So instead just - * call it 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. In this case the block - * would be all literals. + + /* ZSTD_LLcode() can't compute litLength price for sizes >= ZSTD_BLOCKSIZE_MAX + * because it isn't representable in the zstd format. + * So instead just pretend it would cost 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. + * In such a case, the block would be all literals. */ if (litLength == ZSTD_BLOCKSIZE_MAX) return BITCOST_MULTIPLIER + ZSTD_litLengthPrice(ZSTD_BLOCKSIZE_MAX - 1, optPtr, optLevel); @@ -289,24 +316,25 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP } /* ZSTD_getMatchPrice() : - * Provides the cost of the match part (offset + matchLength) of a sequence + * Provides the cost of the match part (offset + matchLength) of a sequence. * Must be combined with ZSTD_fullLiteralsCost() to get the full cost of a sequence. - * @offcode : expects a scale where 0,1,2 are repcodes 1-3, and 3+ are real_offsets+2 + * @offBase : sumtype, representing an offset or a repcode, and using numeric representation of ZSTD_storeSeq() * @optLevel: when <2, favors small offset for decompression speed (improved cache efficiency) */ FORCE_INLINE_TEMPLATE U32 -ZSTD_getMatchPrice(U32 const offcode, +ZSTD_getMatchPrice(U32 const offBase, U32 const matchLength, const optState_t* const optPtr, int const optLevel) { U32 price; - U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offcode)); + U32 const offCode = ZSTD_highbit32(offBase); U32 const mlBase = matchLength - MINMATCH; assert(matchLength >= MINMATCH); - if (optPtr->priceType == zop_predef) /* fixed scheme, do not use statistics */ - return WEIGHT(mlBase, optLevel) + ((16 + offCode) * BITCOST_MULTIPLIER); + if (optPtr->priceType == zop_predef) /* fixed scheme, does not use statistics */ + return WEIGHT(mlBase, optLevel) + + ((16 + offCode) * BITCOST_MULTIPLIER); /* emulated offset cost */ /* dynamic statistics */ price = (offCode * BITCOST_MULTIPLIER) + (optPtr->offCodeSumBasePrice - WEIGHT(optPtr->offCodeFreq[offCode], optLevel)); @@ -325,10 +353,10 @@ ZSTD_getMatchPrice(U32 const offcode, } /* ZSTD_updateStats() : - * assumption : literals + litLengtn <= iend */ + * assumption : literals + litLength <= iend */ static void ZSTD_updateStats(optState_t* const optPtr, U32 litLength, const BYTE* literals, - U32 offsetCode, U32 matchLength) + U32 offBase, U32 matchLength) { /* literals */ if (ZSTD_compressedLiterals(optPtr)) { @@ -344,8 +372,8 @@ static void ZSTD_updateStats(optState_t* const optPtr, optPtr->litLengthSum++; } - /* offset code : expected to follow storeSeq() numeric representation */ - { U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offsetCode)); + /* offset code : follows storeSeq() numeric representation */ + { U32 const offCode = ZSTD_highbit32(offBase); assert(offCode <= MaxOff); optPtr->offCodeFreq[offCode]++; optPtr->offCodeSum++; @@ -379,9 +407,11 @@ MEM_STATIC U32 ZSTD_readMINMATCH(const void* memPtr, U32 length) /* Update hashTable3 up to ip (excluded) Assumption : always within prefix (i.e. not within extDict) */ -static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, - U32* nextToUpdate3, - const BYTE* const ip) +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_MatchState_t* ms, + U32* nextToUpdate3, + const BYTE* const ip) { U32* const hashTable3 = ms->hashTable3; U32 const hashLog3 = ms->hashLog3; @@ -408,8 +438,10 @@ static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, * @param ip assumed <= iend-8 . * @param target The target of ZSTD_updateTree_internal() - we are filling to this position * @return : nb of positions added */ -static U32 ZSTD_insertBt1( - const ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertBt1( + const ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iend, U32 const target, U32 const mls, const int extDict) @@ -527,15 +559,16 @@ static U32 ZSTD_insertBt1( } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR void ZSTD_updateTree_internal( - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, const BYTE* const ip, const BYTE* const iend, const U32 mls, const ZSTD_dictMode_e dictMode) { const BYTE* const base = ms->window.base; U32 const target = (U32)(ip - base); U32 idx = ms->nextToUpdate; - DEBUGLOG(6, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", + DEBUGLOG(7, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", idx, target, dictMode); while(idx < target) { @@ -548,20 +581,23 @@ void ZSTD_updateTree_internal( ms->nextToUpdate = target; } -void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend) { +void ZSTD_updateTree(ZSTD_MatchState_t* ms, const BYTE* ip, const BYTE* iend) { ZSTD_updateTree_internal(ms, ip, iend, ms->cParams.minMatch, ZSTD_noDict); } FORCE_INLINE_TEMPLATE -U32 ZSTD_insertBtAndGetAllMatches ( - ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ - ZSTD_matchState_t* ms, - U32* nextToUpdate3, - const BYTE* const ip, const BYTE* const iLimit, const ZSTD_dictMode_e dictMode, - const U32 rep[ZSTD_REP_NUM], - U32 const ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ - const U32 lengthToBeat, - U32 const mls /* template */) +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 +ZSTD_insertBtAndGetAllMatches ( + ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ + ZSTD_MatchState_t* ms, + U32* nextToUpdate3, + const BYTE* const ip, const BYTE* const iLimit, + const ZSTD_dictMode_e dictMode, + const U32 rep[ZSTD_REP_NUM], + const U32 ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ + const U32 lengthToBeat, + const U32 mls /* template */) { const ZSTD_compressionParameters* const cParams = &ms->cParams; U32 const sufficient_len = MIN(cParams->targetLength, ZSTD_OPT_NUM -1); @@ -590,7 +626,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( U32 mnum = 0; U32 nbCompares = 1U << cParams->searchLog; - const ZSTD_matchState_t* dms = dictMode == ZSTD_dictMatchState ? ms->dictMatchState : NULL; + const ZSTD_MatchState_t* dms = dictMode == ZSTD_dictMatchState ? ms->dictMatchState : NULL; const ZSTD_compressionParameters* const dmsCParams = dictMode == ZSTD_dictMatchState ? &dms->cParams : NULL; const BYTE* const dmsBase = dictMode == ZSTD_dictMatchState ? dms->window.base : NULL; @@ -629,13 +665,13 @@ U32 ZSTD_insertBtAndGetAllMatches ( assert(curr >= windowLow); if ( dictMode == ZSTD_extDict && ( ((repOffset-1) /*intentional overflow*/ < curr - windowLow) /* equivalent to `curr > repIndex >= windowLow` */ - & (((U32)((dictLimit-1) - repIndex) >= 3) ) /* intentional overflow : do not test positions overlapping 2 memory segments */) + & (ZSTD_index_overlap_check(dictLimit, repIndex)) ) && (ZSTD_readMINMATCH(ip, minMatch) == ZSTD_readMINMATCH(repMatch, minMatch)) ) { repLen = (U32)ZSTD_count_2segments(ip+minMatch, repMatch+minMatch, iLimit, dictEnd, prefixStart) + minMatch; } if (dictMode == ZSTD_dictMatchState && ( ((repOffset-1) /*intentional overflow*/ < curr - (dmsLowLimit + dmsIndexDelta)) /* equivalent to `curr > repIndex >= dmsLowLimit` */ - & ((U32)((dictLimit-1) - repIndex) >= 3) ) /* intentional overflow : do not test positions overlapping 2 memory segments */ + & (ZSTD_index_overlap_check(dictLimit, repIndex)) ) && (ZSTD_readMINMATCH(ip, minMatch) == ZSTD_readMINMATCH(repMatch, minMatch)) ) { repLen = (U32)ZSTD_count_2segments(ip+minMatch, repMatch+minMatch, iLimit, dmsEnd, prefixStart) + minMatch; } } @@ -644,7 +680,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( DEBUGLOG(8, "found repCode %u (ll0:%u, offset:%u) of length %u", repCode, ll0, repOffset, repLen); bestLength = repLen; - matches[mnum].off = STORE_REPCODE(repCode - ll0 + 1); /* expect value between 1 and 3 */ + matches[mnum].off = REPCODE_TO_OFFBASE(repCode - ll0 + 1); /* expect value between 1 and 3 */ matches[mnum].len = (U32)repLen; mnum++; if ( (repLen > sufficient_len) @@ -673,7 +709,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( bestLength = mlen; assert(curr > matchIndex3); assert(mnum==0); /* no prior solution */ - matches[0].off = STORE_OFFSET(curr - matchIndex3); + matches[0].off = OFFSET_TO_OFFBASE(curr - matchIndex3); matches[0].len = (U32)mlen; mnum = 1; if ( (mlen > sufficient_len) | @@ -706,13 +742,13 @@ U32 ZSTD_insertBtAndGetAllMatches ( } if (matchLength > bestLength) { - DEBUGLOG(8, "found match of length %u at distance %u (offCode=%u)", - (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); + DEBUGLOG(8, "found match of length %u at distance %u (offBase=%u)", + (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); assert(matchEndIdx > matchIndex); if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; bestLength = matchLength; - matches[mnum].off = STORE_OFFSET(curr - matchIndex); + matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); matches[mnum].len = (U32)matchLength; mnum++; if ( (matchLength > ZSTD_OPT_NUM) @@ -754,12 +790,12 @@ U32 ZSTD_insertBtAndGetAllMatches ( if (matchLength > bestLength) { matchIndex = dictMatchIndex + dmsIndexDelta; - DEBUGLOG(8, "found dms match of length %u at distance %u (offCode=%u)", - (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); + DEBUGLOG(8, "found dms match of length %u at distance %u (offBase=%u)", + (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; bestLength = matchLength; - matches[mnum].off = STORE_OFFSET(curr - matchIndex); + matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); matches[mnum].len = (U32)matchLength; mnum++; if ( (matchLength > ZSTD_OPT_NUM) @@ -784,7 +820,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( typedef U32 (*ZSTD_getAllMatchesFn)( ZSTD_match_t*, - ZSTD_matchState_t*, + ZSTD_MatchState_t*, U32*, const BYTE*, const BYTE*, @@ -792,9 +828,11 @@ typedef U32 (*ZSTD_getAllMatchesFn)( U32 const ll0, U32 const lengthToBeat); -FORCE_INLINE_TEMPLATE U32 ZSTD_btGetAllMatches_internal( +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_btGetAllMatches_internal( ZSTD_match_t* matches, - ZSTD_matchState_t* ms, + ZSTD_MatchState_t* ms, U32* nextToUpdate3, const BYTE* ip, const BYTE* const iHighLimit, @@ -817,7 +855,7 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_btGetAllMatches_internal( #define GEN_ZSTD_BT_GET_ALL_MATCHES_(dictMode, mls) \ static U32 ZSTD_BT_GET_ALL_MATCHES_FN(dictMode, mls)( \ ZSTD_match_t* matches, \ - ZSTD_matchState_t* ms, \ + ZSTD_MatchState_t* ms, \ U32* nextToUpdate3, \ const BYTE* ip, \ const BYTE* const iHighLimit, \ @@ -849,7 +887,7 @@ GEN_ZSTD_BT_GET_ALL_MATCHES(dictMatchState) } static ZSTD_getAllMatchesFn -ZSTD_selectBtGetAllMatches(ZSTD_matchState_t const* ms, ZSTD_dictMode_e const dictMode) +ZSTD_selectBtGetAllMatches(ZSTD_MatchState_t const* ms, ZSTD_dictMode_e const dictMode) { ZSTD_getAllMatchesFn const getAllMatchesFns[3][4] = { ZSTD_BT_GET_ALL_MATCHES_ARRAY(noDict), @@ -868,7 +906,7 @@ ZSTD_selectBtGetAllMatches(ZSTD_matchState_t const* ms, ZSTD_dictMode_e const di /* Struct containing info needed to make decision about ldm inclusion */ typedef struct { - rawSeqStore_t seqStore; /* External match candidates store for this block */ + RawSeqStore_t seqStore; /* External match candidates store for this block */ U32 startPosInBlock; /* Start position of the current match candidate */ U32 endPosInBlock; /* End position of the current match candidate */ U32 offset; /* Offset of the match candidate */ @@ -878,7 +916,7 @@ typedef struct { * Moves forward in @rawSeqStore by @nbBytes, * which will update the fields 'pos' and 'posInSequence'. */ -static void ZSTD_optLdm_skipRawSeqStoreBytes(rawSeqStore_t* rawSeqStore, size_t nbBytes) +static void ZSTD_optLdm_skipRawSeqStoreBytes(RawSeqStore_t* rawSeqStore, size_t nbBytes) { U32 currPos = (U32)(rawSeqStore->posInSequence + nbBytes); while (currPos && rawSeqStore->pos < rawSeqStore->size) { @@ -935,7 +973,7 @@ ZSTD_opt_getNextMatchAndUpdateSeqStore(ZSTD_optLdm_t* optLdm, U32 currPosInBlock return; } - /* Matches may be < MINMATCH by this process. In that case, we will reject them + /* Matches may be < minMatch by this process. In that case, we will reject them when we are deciding whether or not to add the ldm */ optLdm->startPosInBlock = currPosInBlock + literalsBytesRemaining; optLdm->endPosInBlock = optLdm->startPosInBlock + matchBytesRemaining; @@ -957,25 +995,26 @@ ZSTD_opt_getNextMatchAndUpdateSeqStore(ZSTD_optLdm_t* optLdm, U32 currPosInBlock * into 'matches'. Maintains the correct ordering of 'matches'. */ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, - const ZSTD_optLdm_t* optLdm, U32 currPosInBlock) + const ZSTD_optLdm_t* optLdm, U32 currPosInBlock, + U32 minMatch) { U32 const posDiff = currPosInBlock - optLdm->startPosInBlock; - /* Note: ZSTD_match_t actually contains offCode and matchLength (before subtracting MINMATCH) */ + /* Note: ZSTD_match_t actually contains offBase and matchLength (before subtracting MINMATCH) */ U32 const candidateMatchLength = optLdm->endPosInBlock - optLdm->startPosInBlock - posDiff; /* Ensure that current block position is not outside of the match */ if (currPosInBlock < optLdm->startPosInBlock || currPosInBlock >= optLdm->endPosInBlock - || candidateMatchLength < MINMATCH) { + || candidateMatchLength < minMatch) { return; } if (*nbMatches == 0 || ((candidateMatchLength > matches[*nbMatches-1].len) && *nbMatches < ZSTD_OPT_NUM)) { - U32 const candidateOffCode = STORE_OFFSET(optLdm->offset); - DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offCode: %u matchLength %u) at block position=%u", - candidateOffCode, candidateMatchLength, currPosInBlock); + U32 const candidateOffBase = OFFSET_TO_OFFBASE(optLdm->offset); + DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offBase: %u matchLength %u) at block position=%u", + candidateOffBase, candidateMatchLength, currPosInBlock); matches[*nbMatches].len = candidateMatchLength; - matches[*nbMatches].off = candidateOffCode; + matches[*nbMatches].off = candidateOffBase; (*nbMatches)++; } } @@ -986,7 +1025,8 @@ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, static void ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, ZSTD_match_t* matches, U32* nbMatches, - U32 currPosInBlock, U32 remainingBytes) + U32 currPosInBlock, U32 remainingBytes, + U32 minMatch) { if (optLdm->seqStore.size == 0 || optLdm->seqStore.pos >= optLdm->seqStore.size) { return; @@ -1003,7 +1043,7 @@ ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, } ZSTD_opt_getNextMatchAndUpdateSeqStore(optLdm, currPosInBlock, remainingBytes); } - ZSTD_optLdm_maybeAddMatch(matches, nbMatches, optLdm, currPosInBlock); + ZSTD_optLdm_maybeAddMatch(matches, nbMatches, optLdm, currPosInBlock, minMatch); } @@ -1011,11 +1051,6 @@ ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, * Optimal parser *********************************/ -static U32 ZSTD_totalLen(ZSTD_optimal_t sol) -{ - return sol.litlen + sol.mlen; -} - #if 0 /* debug */ static void @@ -1033,9 +1068,15 @@ listStats(const U32* table, int lastEltID) #endif -FORCE_INLINE_TEMPLATE size_t -ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, - seqStore_t* seqStore, +#define LIT_PRICE(_p) (int)ZSTD_rawLiteralsCost(_p, 1, optStatePtr, optLevel) +#define LL_PRICE(_l) (int)ZSTD_litLengthPrice(_l, optStatePtr, optLevel) +#define LL_INCPRICE(_l) (LL_PRICE(_l) - LL_PRICE(_l-1)) + +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t +ZSTD_compressBlock_opt_generic(ZSTD_MatchState_t* ms, + SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const int optLevel, @@ -1059,9 +1100,11 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, ZSTD_optimal_t* const opt = optStatePtr->priceTable; ZSTD_match_t* const matches = optStatePtr->matchTable; - ZSTD_optimal_t lastSequence; + ZSTD_optimal_t lastStretch; ZSTD_optLdm_t optLdm; + ZSTD_memset(&lastStretch, 0, sizeof(ZSTD_optimal_t)); + optLdm.seqStore = ms->ldmSeqStore ? *ms->ldmSeqStore : kNullRawSeqStore; optLdm.endPosInBlock = optLdm.startPosInBlock = optLdm.offset = 0; ZSTD_opt_getNextMatchAndUpdateSeqStore(&optLdm, (U32)(ip-istart), (U32)(iend-ip)); @@ -1082,103 +1125,140 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, U32 const ll0 = !litlen; U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, ip, iend, rep, ll0, minMatch); ZSTD_optLdm_processMatchCandidate(&optLdm, matches, &nbMatches, - (U32)(ip-istart), (U32)(iend - ip)); - if (!nbMatches) { ip++; continue; } + (U32)(ip-istart), (U32)(iend-ip), + minMatch); + if (!nbMatches) { + DEBUGLOG(8, "no match found at cPos %u", (unsigned)(ip-istart)); + ip++; + continue; + } + + /* Match found: let's store this solution, and eventually find more candidates. + * During this forward pass, @opt is used to store stretches, + * defined as "a match followed by N literals". + * Note how this is different from a Sequence, which is "N literals followed by a match". + * Storing stretches allows us to store different match predecessors + * for each literal position part of a literals run. */ /* initialize opt[0] */ - { U32 i ; for (i=0; i immediate encoding */ { U32 const maxML = matches[nbMatches-1].len; - U32 const maxOffcode = matches[nbMatches-1].off; - DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffCode=%u at cPos=%u => start new series", - nbMatches, maxML, maxOffcode, (U32)(ip-prefixStart)); + U32 const maxOffBase = matches[nbMatches-1].off; + DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffBase=%u at cPos=%u => start new series", + nbMatches, maxML, maxOffBase, (U32)(ip-prefixStart)); if (maxML > sufficient_len) { - lastSequence.litlen = litlen; - lastSequence.mlen = maxML; - lastSequence.off = maxOffcode; - DEBUGLOG(6, "large match (%u>%u), immediate encoding", + lastStretch.litlen = 0; + lastStretch.mlen = maxML; + lastStretch.off = maxOffBase; + DEBUGLOG(6, "large match (%u>%u) => immediate encoding", maxML, sufficient_len); cur = 0; - last_pos = ZSTD_totalLen(lastSequence); + last_pos = maxML; goto _shortestPath; } } /* set prices for first matches starting position == 0 */ assert(opt[0].price >= 0); - { U32 const literalsPrice = (U32)opt[0].price + ZSTD_litLengthPrice(0, optStatePtr, optLevel); - U32 pos; + { U32 pos; U32 matchNb; for (pos = 1; pos < minMatch; pos++) { - opt[pos].price = ZSTD_MAX_PRICE; /* mlen, litlen and price will be fixed during forward scanning */ + opt[pos].price = ZSTD_MAX_PRICE; + opt[pos].mlen = 0; + opt[pos].litlen = litlen + pos; } for (matchNb = 0; matchNb < nbMatches; matchNb++) { - U32 const offcode = matches[matchNb].off; + U32 const offBase = matches[matchNb].off; U32 const end = matches[matchNb].len; for ( ; pos <= end ; pos++ ) { - U32 const matchPrice = ZSTD_getMatchPrice(offcode, pos, optStatePtr, optLevel); - U32 const sequencePrice = literalsPrice + matchPrice; + int const matchPrice = (int)ZSTD_getMatchPrice(offBase, pos, optStatePtr, optLevel); + int const sequencePrice = opt[0].price + matchPrice; DEBUGLOG(7, "rPos:%u => set initial price : %.2f", pos, ZSTD_fCost(sequencePrice)); opt[pos].mlen = pos; - opt[pos].off = offcode; - opt[pos].litlen = litlen; - opt[pos].price = (int)sequencePrice; - } } + opt[pos].off = offBase; + opt[pos].litlen = 0; /* end of match */ + opt[pos].price = sequencePrice + LL_PRICE(0); + } + } last_pos = pos-1; + opt[pos].price = ZSTD_MAX_PRICE; } } /* check further positions */ for (cur = 1; cur <= last_pos; cur++) { const BYTE* const inr = ip + cur; - assert(cur < ZSTD_OPT_NUM); - DEBUGLOG(7, "cPos:%zi==rPos:%u", inr-istart, cur) + assert(cur <= ZSTD_OPT_NUM); + DEBUGLOG(7, "cPos:%i==rPos:%u", (int)(inr-istart), cur); /* Fix current position with one literal if cheaper */ - { U32 const litlen = (opt[cur-1].mlen == 0) ? opt[cur-1].litlen + 1 : 1; + { U32 const litlen = opt[cur-1].litlen + 1; int const price = opt[cur-1].price - + (int)ZSTD_rawLiteralsCost(ip+cur-1, 1, optStatePtr, optLevel) - + (int)ZSTD_litLengthPrice(litlen, optStatePtr, optLevel) - - (int)ZSTD_litLengthPrice(litlen-1, optStatePtr, optLevel); + + LIT_PRICE(ip+cur-1) + + LL_INCPRICE(litlen); assert(price < 1000000000); /* overflow check */ if (price <= opt[cur].price) { - DEBUGLOG(7, "cPos:%zi==rPos:%u : better price (%.2f<=%.2f) using literal (ll==%u) (hist:%u,%u,%u)", - inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), litlen, + ZSTD_optimal_t const prevMatch = opt[cur]; + DEBUGLOG(7, "cPos:%i==rPos:%u : better price (%.2f<=%.2f) using literal (ll==%u) (hist:%u,%u,%u)", + (int)(inr-istart), cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), litlen, opt[cur-1].rep[0], opt[cur-1].rep[1], opt[cur-1].rep[2]); - opt[cur].mlen = 0; - opt[cur].off = 0; + opt[cur] = opt[cur-1]; opt[cur].litlen = litlen; opt[cur].price = price; + if ( (optLevel >= 1) /* additional check only for higher modes */ + && (prevMatch.litlen == 0) /* replace a match */ + && (LL_INCPRICE(1) < 0) /* ll1 is cheaper than ll0 */ + && LIKELY(ip + cur < iend) + ) { + /* check next position, in case it would be cheaper */ + int with1literal = prevMatch.price + LIT_PRICE(ip+cur) + LL_INCPRICE(1); + int withMoreLiterals = price + LIT_PRICE(ip+cur) + LL_INCPRICE(litlen+1); + DEBUGLOG(7, "then at next rPos %u : match+1lit %.2f vs %ulits %.2f", + cur+1, ZSTD_fCost(with1literal), litlen+1, ZSTD_fCost(withMoreLiterals)); + if ( (with1literal < withMoreLiterals) + && (with1literal < opt[cur+1].price) ) { + /* update offset history - before it disappears */ + U32 const prev = cur - prevMatch.mlen; + Repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, prevMatch.off, opt[prev].litlen==0); + assert(cur >= prevMatch.mlen); + DEBUGLOG(7, "==> match+1lit is cheaper (%.2f < %.2f) (hist:%u,%u,%u) !", + ZSTD_fCost(with1literal), ZSTD_fCost(withMoreLiterals), + newReps.rep[0], newReps.rep[1], newReps.rep[2] ); + opt[cur+1] = prevMatch; /* mlen & offbase */ + ZSTD_memcpy(opt[cur+1].rep, &newReps, sizeof(Repcodes_t)); + opt[cur+1].litlen = 1; + opt[cur+1].price = with1literal; + if (last_pos < cur+1) last_pos = cur+1; + } + } } else { - DEBUGLOG(7, "cPos:%zi==rPos:%u : literal would cost more (%.2f>%.2f) (hist:%u,%u,%u)", - inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), - opt[cur].rep[0], opt[cur].rep[1], opt[cur].rep[2]); + DEBUGLOG(7, "cPos:%i==rPos:%u : literal would cost more (%.2f>%.2f)", + (int)(inr-istart), cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price)); } } - /* Set the repcodes of the current position. We must do it here - * because we rely on the repcodes of the 2nd to last sequence being - * correct to set the next chunks repcodes during the backward - * traversal. + /* Offset history is not updated during match comparison. + * Do it here, now that the match is selected and confirmed. */ - ZSTD_STATIC_ASSERT(sizeof(opt[cur].rep) == sizeof(repcodes_t)); + ZSTD_STATIC_ASSERT(sizeof(opt[cur].rep) == sizeof(Repcodes_t)); assert(cur >= opt[cur].mlen); - if (opt[cur].mlen != 0) { + if (opt[cur].litlen == 0) { + /* just finished a match => alter offset history */ U32 const prev = cur - opt[cur].mlen; - repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[cur].litlen==0); - ZSTD_memcpy(opt[cur].rep, &newReps, sizeof(repcodes_t)); - } else { - ZSTD_memcpy(opt[cur].rep, opt[cur - 1].rep, sizeof(repcodes_t)); + Repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[prev].litlen==0); + ZSTD_memcpy(opt[cur].rep, &newReps, sizeof(Repcodes_t)); } /* last match must start at a minimum distance of 8 from oend */ @@ -1188,38 +1268,37 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, if ( (optLevel==0) /*static_test*/ && (opt[cur+1].price <= opt[cur].price + (BITCOST_MULTIPLIER/2)) ) { - DEBUGLOG(7, "move to next rPos:%u : price is <=", cur+1); + DEBUGLOG(7, "skip current position : next rPos(%u) price is cheaper", cur+1); continue; /* skip unpromising positions; about ~+6% speed, -0.01 ratio */ } assert(opt[cur].price >= 0); - { U32 const ll0 = (opt[cur].mlen != 0); - U32 const litlen = (opt[cur].mlen == 0) ? opt[cur].litlen : 0; - U32 const previousPrice = (U32)opt[cur].price; - U32 const basePrice = previousPrice + ZSTD_litLengthPrice(0, optStatePtr, optLevel); + { U32 const ll0 = (opt[cur].litlen == 0); + int const previousPrice = opt[cur].price; + int const basePrice = previousPrice + LL_PRICE(0); U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, inr, iend, opt[cur].rep, ll0, minMatch); U32 matchNb; ZSTD_optLdm_processMatchCandidate(&optLdm, matches, &nbMatches, - (U32)(inr-istart), (U32)(iend-inr)); + (U32)(inr-istart), (U32)(iend-inr), + minMatch); if (!nbMatches) { DEBUGLOG(7, "rPos:%u : no match found", cur); continue; } - { U32 const maxML = matches[nbMatches-1].len; - DEBUGLOG(7, "cPos:%zi==rPos:%u, found %u matches, of maxLength=%u", - inr-istart, cur, nbMatches, maxML); - - if ( (maxML > sufficient_len) - || (cur + maxML >= ZSTD_OPT_NUM) ) { - lastSequence.mlen = maxML; - lastSequence.off = matches[nbMatches-1].off; - lastSequence.litlen = litlen; - cur -= (opt[cur].mlen==0) ? opt[cur].litlen : 0; /* last sequence is actually only literals, fix cur to last match - note : may underflow, in which case, it's first sequence, and it's okay */ - last_pos = cur + ZSTD_totalLen(lastSequence); - if (cur > ZSTD_OPT_NUM) cur = 0; /* underflow => first match */ + { U32 const longestML = matches[nbMatches-1].len; + DEBUGLOG(7, "cPos:%i==rPos:%u, found %u matches, of longest ML=%u", + (int)(inr-istart), cur, nbMatches, longestML); + + if ( (longestML > sufficient_len) + || (cur + longestML >= ZSTD_OPT_NUM) + || (ip + cur + longestML >= iend) ) { + lastStretch.mlen = longestML; + lastStretch.off = matches[nbMatches-1].off; + lastStretch.litlen = 0; + last_pos = cur + longestML; goto _shortestPath; } } @@ -1230,20 +1309,25 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, U32 const startML = (matchNb>0) ? matches[matchNb-1].len+1 : minMatch; U32 mlen; - DEBUGLOG(7, "testing match %u => offCode=%4u, mlen=%2u, llen=%2u", - matchNb, matches[matchNb].off, lastML, litlen); + DEBUGLOG(7, "testing match %u => offBase=%4u, mlen=%2u, llen=%2u", + matchNb, matches[matchNb].off, lastML, opt[cur].litlen); for (mlen = lastML; mlen >= startML; mlen--) { /* scan downward */ U32 const pos = cur + mlen; - int const price = (int)basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); + int const price = basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); if ((pos > last_pos) || (price < opt[pos].price)) { DEBUGLOG(7, "rPos:%u (ml=%2u) => new better price (%.2f<%.2f)", pos, mlen, ZSTD_fCost(price), ZSTD_fCost(opt[pos].price)); - while (last_pos < pos) { opt[last_pos+1].price = ZSTD_MAX_PRICE; last_pos++; } /* fill empty positions */ + while (last_pos < pos) { + /* fill empty positions, for future comparisons */ + last_pos++; + opt[last_pos].price = ZSTD_MAX_PRICE; + opt[last_pos].litlen = !0; /* just needs to be != 0, to mean "not an end of match" */ + } opt[pos].mlen = mlen; opt[pos].off = offset; - opt[pos].litlen = litlen; + opt[pos].litlen = 0; opt[pos].price = price; } else { DEBUGLOG(7, "rPos:%u (ml=%2u) => new price is worse (%.2f>=%.2f)", @@ -1251,55 +1335,89 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, if (optLevel==0) break; /* early update abort; gets ~+10% speed for about -0.01 ratio loss */ } } } } + opt[last_pos+1].price = ZSTD_MAX_PRICE; } /* for (cur = 1; cur <= last_pos; cur++) */ - lastSequence = opt[last_pos]; - cur = last_pos > ZSTD_totalLen(lastSequence) ? last_pos - ZSTD_totalLen(lastSequence) : 0; /* single sequence, and it starts before `ip` */ - assert(cur < ZSTD_OPT_NUM); /* control overflow*/ + lastStretch = opt[last_pos]; + assert(cur >= lastStretch.mlen); + cur = last_pos - lastStretch.mlen; _shortestPath: /* cur, last_pos, best_mlen, best_off have to be set */ assert(opt[0].mlen == 0); + assert(last_pos >= lastStretch.mlen); + assert(cur == last_pos - lastStretch.mlen); - /* Set the next chunk's repcodes based on the repcodes of the beginning - * of the last match, and the last sequence. This avoids us having to - * update them while traversing the sequences. - */ - if (lastSequence.mlen != 0) { - repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastSequence.off, lastSequence.litlen==0); - ZSTD_memcpy(rep, &reps, sizeof(reps)); + if (lastStretch.mlen==0) { + /* no solution : all matches have been converted into literals */ + assert(lastStretch.litlen == (ip - anchor) + last_pos); + ip += last_pos; + continue; + } + assert(lastStretch.off > 0); + + /* Update offset history */ + if (lastStretch.litlen == 0) { + /* finishing on a match : update offset history */ + Repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastStretch.off, opt[cur].litlen==0); + ZSTD_memcpy(rep, &reps, sizeof(Repcodes_t)); } else { - ZSTD_memcpy(rep, opt[cur].rep, sizeof(repcodes_t)); + ZSTD_memcpy(rep, lastStretch.rep, sizeof(Repcodes_t)); + assert(cur >= lastStretch.litlen); + cur -= lastStretch.litlen; } - { U32 const storeEnd = cur + 1; + /* Let's write the shortest path solution. + * It is stored in @opt in reverse order, + * starting from @storeEnd (==cur+2), + * effectively partially @opt overwriting. + * Content is changed too: + * - So far, @opt stored stretches, aka a match followed by literals + * - Now, it will store sequences, aka literals followed by a match + */ + { U32 const storeEnd = cur + 2; U32 storeStart = storeEnd; - U32 seqPos = cur; + U32 stretchPos = cur; DEBUGLOG(6, "start reverse traversal (last_pos:%u, cur:%u)", last_pos, cur); (void)last_pos; - assert(storeEnd < ZSTD_OPT_NUM); - DEBUGLOG(6, "last sequence copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", - storeEnd, lastSequence.litlen, lastSequence.mlen, lastSequence.off); - opt[storeEnd] = lastSequence; - while (seqPos > 0) { - U32 const backDist = ZSTD_totalLen(opt[seqPos]); + assert(storeEnd < ZSTD_OPT_SIZE); + DEBUGLOG(6, "last stretch copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", + storeEnd, lastStretch.litlen, lastStretch.mlen, lastStretch.off); + if (lastStretch.litlen > 0) { + /* last "sequence" is unfinished: just a bunch of literals */ + opt[storeEnd].litlen = lastStretch.litlen; + opt[storeEnd].mlen = 0; + storeStart = storeEnd-1; + opt[storeStart] = lastStretch; + } { + opt[storeEnd] = lastStretch; /* note: litlen will be fixed */ + storeStart = storeEnd; + } + while (1) { + ZSTD_optimal_t nextStretch = opt[stretchPos]; + opt[storeStart].litlen = nextStretch.litlen; + DEBUGLOG(6, "selected sequence (llen=%u,mlen=%u,ofc=%u)", + opt[storeStart].litlen, opt[storeStart].mlen, opt[storeStart].off); + if (nextStretch.mlen == 0) { + /* reaching beginning of segment */ + break; + } storeStart--; - DEBUGLOG(6, "sequence from rPos=%u copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", - seqPos, storeStart, opt[seqPos].litlen, opt[seqPos].mlen, opt[seqPos].off); - opt[storeStart] = opt[seqPos]; - seqPos = (seqPos > backDist) ? seqPos - backDist : 0; + opt[storeStart] = nextStretch; /* note: litlen will be fixed */ + assert(nextStretch.litlen + nextStretch.mlen <= stretchPos); + stretchPos -= nextStretch.litlen + nextStretch.mlen; } /* save sequences */ - DEBUGLOG(6, "sending selected sequences into seqStore") + DEBUGLOG(6, "sending selected sequences into seqStore"); { U32 storePos; for (storePos=storeStart; storePos <= storeEnd; storePos++) { U32 const llen = opt[storePos].litlen; U32 const mlen = opt[storePos].mlen; - U32 const offCode = opt[storePos].off; + U32 const offBase = opt[storePos].off; U32 const advance = llen + mlen; - DEBUGLOG(6, "considering seq starting at %zi, llen=%u, mlen=%u", - anchor - istart, (unsigned)llen, (unsigned)mlen); + DEBUGLOG(6, "considering seq starting at %i, llen=%u, mlen=%u", + (int)(anchor - istart), (unsigned)llen, (unsigned)mlen); if (mlen==0) { /* only literals => must be last "sequence", actually starting a new stream of sequences */ assert(storePos == storeEnd); /* must be last sequence */ @@ -1308,11 +1426,14 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, } assert(anchor + llen <= iend); - ZSTD_updateStats(optStatePtr, llen, anchor, offCode, mlen); - ZSTD_storeSeq(seqStore, llen, anchor, iend, offCode, mlen); + ZSTD_updateStats(optStatePtr, llen, anchor, offBase, mlen); + ZSTD_storeSeq(seqStore, llen, anchor, iend, offBase, mlen); anchor += advance; ip = anchor; } } + DEBUGLOG(7, "new offset history : %u, %u, %u", rep[0], rep[1], rep[2]); + + /* update all costs */ ZSTD_setBasePrices(optStatePtr, optLevel); } } /* while (ip < ilimit) */ @@ -1320,42 +1441,51 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR static size_t ZSTD_compressBlock_opt0( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) { return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 0 /* optLevel */, dictMode); } +#endif +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR static size_t ZSTD_compressBlock_opt2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) { return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 2 /* optLevel */, dictMode); } +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressBlock_btopt"); return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_noDict); } +#endif +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR /* ZSTD_initStats_ultra(): * make a first compression pass, just to seed stats with more accurate starting values. * only works on first block, with no dictionary and no ldm. - * this function cannot error, hence its contract must be respected. + * this function cannot error out, its narrow contract must be respected. */ -static void -ZSTD_initStats_ultra(ZSTD_matchState_t* ms, - seqStore_t* seqStore, - U32 rep[ZSTD_REP_NUM], - const void* src, size_t srcSize) +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_initStats_ultra(ZSTD_MatchState_t* ms, + SeqStore_t* seqStore, + U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) { U32 tmpRep[ZSTD_REP_NUM]; /* updated rep codes will sink here */ ZSTD_memcpy(tmpRep, rep, sizeof(tmpRep)); @@ -1368,7 +1498,7 @@ ZSTD_initStats_ultra(ZSTD_matchState_t* ms, ZSTD_compressBlock_opt2(ms, seqStore, tmpRep, src, srcSize, ZSTD_noDict); /* generate stats into ms->opt*/ - /* invalidate first scan from history */ + /* invalidate first scan from history, only keep entropy stats */ ZSTD_resetSeqStore(seqStore); ms->window.base -= srcSize; ms->window.dictLimit += (U32)srcSize; @@ -1378,7 +1508,7 @@ ZSTD_initStats_ultra(ZSTD_matchState_t* ms, } size_t ZSTD_compressBlock_btultra( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressBlock_btultra (srcSize=%zu)", srcSize); @@ -1386,16 +1516,16 @@ size_t ZSTD_compressBlock_btultra( } size_t ZSTD_compressBlock_btultra2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { U32 const curr = (U32)((const BYTE*)src - ms->window.base); DEBUGLOG(5, "ZSTD_compressBlock_btultra2 (srcSize=%zu)", srcSize); - /* 2-pass strategy: + /* 2-passes strategy: * this strategy makes a first pass over first block to collect statistics - * and seed next round's statistics with it. - * After 1st pass, function forgets everything, and starts a new block. + * in order to seed next round's statistics with it. + * After 1st pass, function forgets history, and starts a new block. * Consequently, this can only work if no data has been previously loaded in tables, * aka, no dictionary, no prefix, no ldm preprocessing. * The compression ratio gain is generally small (~0.5% on first block), @@ -1404,42 +1534,47 @@ size_t ZSTD_compressBlock_btultra2( if ( (ms->opt.litLengthSum==0) /* first block */ && (seqStore->sequences == seqStore->sequencesStart) /* no ldm */ && (ms->window.dictLimit == ms->window.lowLimit) /* no dictionary */ - && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ - && (srcSize > ZSTD_PREDEF_THRESHOLD) + && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ + && (srcSize > ZSTD_PREDEF_THRESHOLD) /* input large enough to not employ default stats */ ) { ZSTD_initStats_ultra(ms, seqStore, rep, src, srcSize); } return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_noDict); } +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_btultra_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_btopt_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { - return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); + return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); } +#endif -size_t ZSTD_compressBlock_btopt_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btultra_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { - return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); + return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); } size_t ZSTD_compressBlock_btultra_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_extDict); } +#endif /* note : no btultra2 variant for extDict nor dictMatchState, * because btultra2 is not meant to work with dictionaries diff --git a/lib/zstd/compress/zstd_opt.h b/lib/zstd/compress/zstd_opt.h index 22b862858ba7..fbdc540ec9d1 100644 --- a/lib/zstd/compress/zstd_opt.h +++ b/lib/zstd/compress/zstd_opt.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,40 +12,62 @@ #ifndef ZSTD_OPT_H #define ZSTD_OPT_H - #include "zstd_compress_internal.h" +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) /* used in ZSTD_loadDictionaryContent() */ -void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend); +void ZSTD_updateTree(ZSTD_MatchState_t* ms, const BYTE* ip, const BYTE* iend); +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_btultra( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_btopt_dictMatchState( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_btultra2( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +size_t ZSTD_compressBlock_btopt_extDict( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); +#define ZSTD_COMPRESSBLOCK_BTOPT ZSTD_compressBlock_btopt +#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE ZSTD_compressBlock_btopt_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT ZSTD_compressBlock_btopt_extDict +#else +#define ZSTD_COMPRESSBLOCK_BTOPT NULL +#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT NULL +#endif -size_t ZSTD_compressBlock_btopt_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btultra( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_btultra_dictMatchState( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_btopt_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_btultra_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); /* note : no btultra2 variant for extDict nor dictMatchState, * because btultra2 is not meant to work with dictionaries * and is only specific for the first block (no prefix) */ +size_t ZSTD_compressBlock_btultra2( + ZSTD_MatchState_t* ms, SeqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); +#define ZSTD_COMPRESSBLOCK_BTULTRA ZSTD_compressBlock_btultra +#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE ZSTD_compressBlock_btultra_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT ZSTD_compressBlock_btultra_extDict +#define ZSTD_COMPRESSBLOCK_BTULTRA2 ZSTD_compressBlock_btultra2 +#else +#define ZSTD_COMPRESSBLOCK_BTULTRA NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA2 NULL +#endif #endif /* ZSTD_OPT_H */ diff --git a/lib/zstd/compress/zstd_preSplit.c b/lib/zstd/compress/zstd_preSplit.c new file mode 100644 index 000000000000..7d9403c9a3bc --- /dev/null +++ b/lib/zstd/compress/zstd_preSplit.c @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#include "../common/compiler.h" /* ZSTD_ALIGNOF */ +#include "../common/mem.h" /* S64 */ +#include "../common/zstd_deps.h" /* ZSTD_memset */ +#include "../common/zstd_internal.h" /* ZSTD_STATIC_ASSERT */ +#include "hist.h" /* HIST_add */ +#include "zstd_preSplit.h" + + +#define BLOCKSIZE_MIN 3500 +#define THRESHOLD_PENALTY_RATE 16 +#define THRESHOLD_BASE (THRESHOLD_PENALTY_RATE - 2) +#define THRESHOLD_PENALTY 3 + +#define HASHLENGTH 2 +#define HASHLOG_MAX 10 +#define HASHTABLESIZE (1 << HASHLOG_MAX) +#define HASHMASK (HASHTABLESIZE - 1) +#define KNUTH 0x9e3779b9 + +/* for hashLog > 8, hash 2 bytes. + * for hashLog == 8, just take the byte, no hashing. + * The speed of this method relies on compile-time constant propagation */ +FORCE_INLINE_TEMPLATE unsigned hash2(const void *p, unsigned hashLog) +{ + assert(hashLog >= 8); + if (hashLog == 8) return (U32)((const BYTE*)p)[0]; + assert(hashLog <= HASHLOG_MAX); + return (U32)(MEM_read16(p)) * KNUTH >> (32 - hashLog); +} + + +typedef struct { + unsigned events[HASHTABLESIZE]; + size_t nbEvents; +} Fingerprint; +typedef struct { + Fingerprint pastEvents; + Fingerprint newEvents; +} FPStats; + +static void initStats(FPStats* fpstats) +{ + ZSTD_memset(fpstats, 0, sizeof(FPStats)); +} + +FORCE_INLINE_TEMPLATE void +addEvents_generic(Fingerprint* fp, const void* src, size_t srcSize, size_t samplingRate, unsigned hashLog) +{ + const char* p = (const char*)src; + size_t limit = srcSize - HASHLENGTH + 1; + size_t n; + assert(srcSize >= HASHLENGTH); + for (n = 0; n < limit; n+=samplingRate) { + fp->events[hash2(p+n, hashLog)]++; + } + fp->nbEvents += limit/samplingRate; +} + +FORCE_INLINE_TEMPLATE void +recordFingerprint_generic(Fingerprint* fp, const void* src, size_t srcSize, size_t samplingRate, unsigned hashLog) +{ + ZSTD_memset(fp, 0, sizeof(unsigned) * ((size_t)1 << hashLog)); + fp->nbEvents = 0; + addEvents_generic(fp, src, srcSize, samplingRate, hashLog); +} + +typedef void (*RecordEvents_f)(Fingerprint* fp, const void* src, size_t srcSize); + +#define FP_RECORD(_rate) ZSTD_recordFingerprint_##_rate + +#define ZSTD_GEN_RECORD_FINGERPRINT(_rate, _hSize) \ + static void FP_RECORD(_rate)(Fingerprint* fp, const void* src, size_t srcSize) \ + { \ + recordFingerprint_generic(fp, src, srcSize, _rate, _hSize); \ + } + +ZSTD_GEN_RECORD_FINGERPRINT(1, 10) +ZSTD_GEN_RECORD_FINGERPRINT(5, 10) +ZSTD_GEN_RECORD_FINGERPRINT(11, 9) +ZSTD_GEN_RECORD_FINGERPRINT(43, 8) + + +static U64 abs64(S64 s64) { return (U64)((s64 < 0) ? -s64 : s64); } + +static U64 fpDistance(const Fingerprint* fp1, const Fingerprint* fp2, unsigned hashLog) +{ + U64 distance = 0; + size_t n; + assert(hashLog <= HASHLOG_MAX); + for (n = 0; n < ((size_t)1 << hashLog); n++) { + distance += + abs64((S64)fp1->events[n] * (S64)fp2->nbEvents - (S64)fp2->events[n] * (S64)fp1->nbEvents); + } + return distance; +} + +/* Compare newEvents with pastEvents + * return 1 when considered "too different" + */ +static int compareFingerprints(const Fingerprint* ref, + const Fingerprint* newfp, + int penalty, + unsigned hashLog) +{ + assert(ref->nbEvents > 0); + assert(newfp->nbEvents > 0); + { U64 p50 = (U64)ref->nbEvents * (U64)newfp->nbEvents; + U64 deviation = fpDistance(ref, newfp, hashLog); + U64 threshold = p50 * (U64)(THRESHOLD_BASE + penalty) / THRESHOLD_PENALTY_RATE; + return deviation >= threshold; + } +} + +static void mergeEvents(Fingerprint* acc, const Fingerprint* newfp) +{ + size_t n; + for (n = 0; n < HASHTABLESIZE; n++) { + acc->events[n] += newfp->events[n]; + } + acc->nbEvents += newfp->nbEvents; +} + +static void flushEvents(FPStats* fpstats) +{ + size_t n; + for (n = 0; n < HASHTABLESIZE; n++) { + fpstats->pastEvents.events[n] = fpstats->newEvents.events[n]; + } + fpstats->pastEvents.nbEvents = fpstats->newEvents.nbEvents; + ZSTD_memset(&fpstats->newEvents, 0, sizeof(fpstats->newEvents)); +} + +static void removeEvents(Fingerprint* acc, const Fingerprint* slice) +{ + size_t n; + for (n = 0; n < HASHTABLESIZE; n++) { + assert(acc->events[n] >= slice->events[n]); + acc->events[n] -= slice->events[n]; + } + acc->nbEvents -= slice->nbEvents; +} + +#define CHUNKSIZE (8 << 10) +static size_t ZSTD_splitBlock_byChunks(const void* blockStart, size_t blockSize, + int level, + void* workspace, size_t wkspSize) +{ + static const RecordEvents_f records_fs[] = { + FP_RECORD(43), FP_RECORD(11), FP_RECORD(5), FP_RECORD(1) + }; + static const unsigned hashParams[] = { 8, 9, 10, 10 }; + const RecordEvents_f record_f = (assert(0<=level && level<=3), records_fs[level]); + FPStats* const fpstats = (FPStats*)workspace; + const char* p = (const char*)blockStart; + int penalty = THRESHOLD_PENALTY; + size_t pos = 0; + assert(blockSize == (128 << 10)); + assert(workspace != NULL); + assert((size_t)workspace % ZSTD_ALIGNOF(FPStats) == 0); + ZSTD_STATIC_ASSERT(ZSTD_SLIPBLOCK_WORKSPACESIZE >= sizeof(FPStats)); + assert(wkspSize >= sizeof(FPStats)); (void)wkspSize; + + initStats(fpstats); + record_f(&fpstats->pastEvents, p, CHUNKSIZE); + for (pos = CHUNKSIZE; pos <= blockSize - CHUNKSIZE; pos += CHUNKSIZE) { + record_f(&fpstats->newEvents, p + pos, CHUNKSIZE); + if (compareFingerprints(&fpstats->pastEvents, &fpstats->newEvents, penalty, hashParams[level])) { + return pos; + } else { + mergeEvents(&fpstats->pastEvents, &fpstats->newEvents); + if (penalty > 0) penalty--; + } + } + assert(pos == blockSize); + return blockSize; + (void)flushEvents; (void)removeEvents; +} + +/* ZSTD_splitBlock_fromBorders(): very fast strategy : + * compare fingerprint from beginning and end of the block, + * derive from their difference if it's preferable to split in the middle, + * repeat the process a second time, for finer grained decision. + * 3 times did not brought improvements, so I stopped at 2. + * Benefits are good enough for a cheap heuristic. + * More accurate splitting saves more, but speed impact is also more perceptible. + * For better accuracy, use more elaborate variant *_byChunks. + */ +static size_t ZSTD_splitBlock_fromBorders(const void* blockStart, size_t blockSize, + void* workspace, size_t wkspSize) +{ +#define SEGMENT_SIZE 512 + FPStats* const fpstats = (FPStats*)workspace; + Fingerprint* middleEvents = (Fingerprint*)(void*)((char*)workspace + 512 * sizeof(unsigned)); + assert(blockSize == (128 << 10)); + assert(workspace != NULL); + assert((size_t)workspace % ZSTD_ALIGNOF(FPStats) == 0); + ZSTD_STATIC_ASSERT(ZSTD_SLIPBLOCK_WORKSPACESIZE >= sizeof(FPStats)); + assert(wkspSize >= sizeof(FPStats)); (void)wkspSize; + + initStats(fpstats); + HIST_add(fpstats->pastEvents.events, blockStart, SEGMENT_SIZE); + HIST_add(fpstats->newEvents.events, (const char*)blockStart + blockSize - SEGMENT_SIZE, SEGMENT_SIZE); + fpstats->pastEvents.nbEvents = fpstats->newEvents.nbEvents = SEGMENT_SIZE; + if (!compareFingerprints(&fpstats->pastEvents, &fpstats->newEvents, 0, 8)) + return blockSize; + + HIST_add(middleEvents->events, (const char*)blockStart + blockSize/2 - SEGMENT_SIZE/2, SEGMENT_SIZE); + middleEvents->nbEvents = SEGMENT_SIZE; + { U64 const distFromBegin = fpDistance(&fpstats->pastEvents, middleEvents, 8); + U64 const distFromEnd = fpDistance(&fpstats->newEvents, middleEvents, 8); + U64 const minDistance = SEGMENT_SIZE * SEGMENT_SIZE / 3; + if (abs64((S64)distFromBegin - (S64)distFromEnd) < minDistance) + return 64 KB; + return (distFromBegin > distFromEnd) ? 32 KB : 96 KB; + } +} + +size_t ZSTD_splitBlock(const void* blockStart, size_t blockSize, + int level, + void* workspace, size_t wkspSize) +{ + DEBUGLOG(6, "ZSTD_splitBlock (level=%i)", level); + assert(0<=level && level<=4); + if (level == 0) + return ZSTD_splitBlock_fromBorders(blockStart, blockSize, workspace, wkspSize); + /* level >= 1*/ + return ZSTD_splitBlock_byChunks(blockStart, blockSize, level-1, workspace, wkspSize); +} diff --git a/lib/zstd/compress/zstd_preSplit.h b/lib/zstd/compress/zstd_preSplit.h new file mode 100644 index 000000000000..f98f797fe191 --- /dev/null +++ b/lib/zstd/compress/zstd_preSplit.h @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#ifndef ZSTD_PRESPLIT_H +#define ZSTD_PRESPLIT_H + +#include /* size_t */ + +#define ZSTD_SLIPBLOCK_WORKSPACESIZE 8208 + +/* ZSTD_splitBlock(): + * @level must be a value between 0 and 4. + * higher levels spend more energy to detect block boundaries. + * @workspace must be aligned for size_t. + * @wkspSize must be at least >= ZSTD_SLIPBLOCK_WORKSPACESIZE + * note: + * For the time being, this function only accepts full 128 KB blocks. + * Therefore, @blockSize must be == 128 KB. + * While this could be extended to smaller sizes in the future, + * it is not yet clear if this would be useful. TBD. + */ +size_t ZSTD_splitBlock(const void* blockStart, size_t blockSize, + int level, + void* workspace, size_t wkspSize); + +#endif /* ZSTD_PRESPLIT_H */ diff --git a/lib/zstd/decompress/huf_decompress.c b/lib/zstd/decompress/huf_decompress.c index 60958afebc41..ac8b87f48f84 100644 --- a/lib/zstd/decompress/huf_decompress.c +++ b/lib/zstd/decompress/huf_decompress.c @@ -1,7 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * huff0 huffman decoder, * part of Finite State Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -19,10 +20,10 @@ #include "../common/compiler.h" #include "../common/bitstream.h" /* BIT_* */ #include "../common/fse.h" /* to compress headers */ -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "../common/error_private.h" #include "../common/zstd_internal.h" +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_countTrailingZeros64 */ /* ************************************************************** * Constants @@ -34,6 +35,12 @@ * Macros ****************************************************************/ +#ifdef HUF_DISABLE_FAST_DECODE +# define HUF_ENABLE_FAST_DECODE 0 +#else +# define HUF_ENABLE_FAST_DECODE 1 +#endif + /* These two optional macros force the use one way or another of the two * Huffman decompression implementations. You can't force in both directions * at the same time. @@ -43,27 +50,25 @@ #error "Cannot force the use of the X1 and X2 decoders at the same time!" #endif -#if ZSTD_ENABLE_ASM_X86_64_BMI2 && DYNAMIC_BMI2 -# define HUF_ASM_X86_64_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE +/* When DYNAMIC_BMI2 is enabled, fast decoders are only called when bmi2 is + * supported at runtime, so we can add the BMI2 target attribute. + * When it is disabled, we will still get BMI2 if it is enabled statically. + */ +#if DYNAMIC_BMI2 +# define HUF_FAST_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE #else -# define HUF_ASM_X86_64_BMI2_ATTRS +# define HUF_FAST_BMI2_ATTRS #endif #define HUF_EXTERN_C #define HUF_ASM_DECL HUF_EXTERN_C -#if DYNAMIC_BMI2 || (ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) +#if DYNAMIC_BMI2 # define HUF_NEED_BMI2_FUNCTION 1 #else # define HUF_NEED_BMI2_FUNCTION 0 #endif -#if !(ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) -# define HUF_NEED_DEFAULT_FUNCTION 1 -#else -# define HUF_NEED_DEFAULT_FUNCTION 0 -#endif - /* ************************************************************** * Error Management ****************************************************************/ @@ -80,6 +85,11 @@ /* ************************************************************** * BMI2 Variant Wrappers ****************************************************************/ +typedef size_t (*HUF_DecompressUsingDTableFn)(void *dst, size_t dstSize, + const void *cSrc, + size_t cSrcSize, + const HUF_DTable *DTable); + #if DYNAMIC_BMI2 #define HUF_DGEN(fn) \ @@ -101,9 +111,9 @@ } \ \ static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ + size_t cSrcSize, HUF_DTable const* DTable, int flags) \ { \ - if (bmi2) { \ + if (flags & HUF_flags_bmi2) { \ return fn##_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); \ } \ return fn##_default(dst, dstSize, cSrc, cSrcSize, DTable); \ @@ -113,9 +123,9 @@ #define HUF_DGEN(fn) \ static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ + size_t cSrcSize, HUF_DTable const* DTable, int flags) \ { \ - (void)bmi2; \ + (void)flags; \ return fn##_body(dst, dstSize, cSrc, cSrcSize, DTable); \ } @@ -134,43 +144,66 @@ static DTableDesc HUF_getDTableDesc(const HUF_DTable* table) return dtd; } -#if ZSTD_ENABLE_ASM_X86_64_BMI2 - -static size_t HUF_initDStream(BYTE const* ip) { +static size_t HUF_initFastDStream(BYTE const* ip) { BYTE const lastByte = ip[7]; - size_t const bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; + size_t const bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; size_t const value = MEM_readLEST(ip) | 1; assert(bitsConsumed <= 8); + assert(sizeof(size_t) == 8); return value << bitsConsumed; } + + +/* + * The input/output arguments to the Huffman fast decoding loop: + * + * ip [in/out] - The input pointers, must be updated to reflect what is consumed. + * op [in/out] - The output pointers, must be updated to reflect what is written. + * bits [in/out] - The bitstream containers, must be updated to reflect the current state. + * dt [in] - The decoding table. + * ilowest [in] - The beginning of the valid range of the input. Decoders may read + * down to this pointer. It may be below iend[0]. + * oend [in] - The end of the output stream. op[3] must not cross oend. + * iend [in] - The end of each input stream. ip[i] may cross iend[i], + * as long as it is above ilowest, but that indicates corruption. + */ typedef struct { BYTE const* ip[4]; BYTE* op[4]; U64 bits[4]; void const* dt; - BYTE const* ilimit; + BYTE const* ilowest; BYTE* oend; BYTE const* iend[4]; -} HUF_DecompressAsmArgs; +} HUF_DecompressFastArgs; + +typedef void (*HUF_DecompressFastLoopFn)(HUF_DecompressFastArgs*); /* - * Initializes args for the asm decoding loop. - * @returns 0 on success - * 1 if the fallback implementation should be used. + * Initializes args for the fast decoding loop. + * @returns 1 on success + * 0 if the fallback implementation should be used. * Or an error code on failure. */ -static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) +static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) { void const* dt = DTable + 1; U32 const dtLog = HUF_getDTableDesc(DTable).tableLog; - const BYTE* const ilimit = (const BYTE*)src + 6 + 8; + const BYTE* const istart = (const BYTE*)src; - BYTE* const oend = (BYTE*)dst + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); - /* The following condition is false on x32 platform, - * but HUF_asm is not compatible with this ABI */ - if (!(MEM_isLittleEndian() && !MEM_32bits())) return 1; + /* The fast decoding loop assumes 64-bit little-endian. + * This condition is false on x32. + */ + if (!MEM_isLittleEndian() || MEM_32bits()) + return 0; + + /* Avoid nullptr addition */ + if (dstSize == 0) + return 0; + assert(dst != NULL); /* strict minimum : jump table + 1 byte per stream */ if (srcSize < 10) @@ -181,11 +214,10 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, * On small inputs we don't have enough data to trigger the fast loop, so use the old decoder. */ if (dtLog != HUF_DECODER_FAST_TABLELOG) - return 1; + return 0; /* Read the jump table. */ { - const BYTE* const istart = (const BYTE*)src; size_t const length1 = MEM_readLE16(istart); size_t const length2 = MEM_readLE16(istart+2); size_t const length3 = MEM_readLE16(istart+4); @@ -195,13 +227,11 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, args->iend[2] = args->iend[1] + length2; args->iend[3] = args->iend[2] + length3; - /* HUF_initDStream() requires this, and this small of an input + /* HUF_initFastDStream() requires this, and this small of an input * won't benefit from the ASM loop anyways. - * length1 must be >= 16 so that ip[0] >= ilimit before the loop - * starts. */ - if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8) - return 1; + if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8) + return 0; if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */ } /* ip[] contains the position that is currently loaded into bits[]. */ @@ -218,7 +248,7 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, /* No point to call the ASM loop for tiny outputs. */ if (args->op[3] >= oend) - return 1; + return 0; /* bits[] is the bit container. * It is read from the MSB down to the LSB. @@ -227,24 +257,25 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, * set, so that CountTrailingZeros(bits[]) can be used * to count how many bits we've consumed. */ - args->bits[0] = HUF_initDStream(args->ip[0]); - args->bits[1] = HUF_initDStream(args->ip[1]); - args->bits[2] = HUF_initDStream(args->ip[2]); - args->bits[3] = HUF_initDStream(args->ip[3]); - - /* If ip[] >= ilimit, it is guaranteed to be safe to - * reload bits[]. It may be beyond its section, but is - * guaranteed to be valid (>= istart). - */ - args->ilimit = ilimit; + args->bits[0] = HUF_initFastDStream(args->ip[0]); + args->bits[1] = HUF_initFastDStream(args->ip[1]); + args->bits[2] = HUF_initFastDStream(args->ip[2]); + args->bits[3] = HUF_initFastDStream(args->ip[3]); + + /* The decoders must be sure to never read beyond ilowest. + * This is lower than iend[0], but allowing decoders to read + * down to ilowest can allow an extra iteration or two in the + * fast loop. + */ + args->ilowest = istart; args->oend = oend; args->dt = dt; - return 0; + return 1; } -static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs const* args, int stream, BYTE* segmentEnd) +static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArgs const* args, int stream, BYTE* segmentEnd) { /* Validate that we haven't overwritten. */ if (args->op[stream] > segmentEnd) @@ -258,15 +289,33 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs return ERROR(corruption_detected); /* Construct the BIT_DStream_t. */ - bit->bitContainer = MEM_readLE64(args->ip[stream]); - bit->bitsConsumed = ZSTD_countTrailingZeros((size_t)args->bits[stream]); - bit->start = (const char*)args->iend[0]; + assert(sizeof(size_t) == 8); + bit->bitContainer = MEM_readLEST(args->ip[stream]); + bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]); + bit->start = (const char*)args->ilowest; bit->limitPtr = bit->start + sizeof(size_t); bit->ptr = (const char*)args->ip[stream]; return 0; } -#endif + +/* Calls X(N) for each stream 0, 1, 2, 3. */ +#define HUF_4X_FOR_EACH_STREAM(X) \ + do { \ + X(0); \ + X(1); \ + X(2); \ + X(3); \ + } while (0) + +/* Calls X(N, var) for each stream 0, 1, 2, 3. */ +#define HUF_4X_FOR_EACH_STREAM_WITH_VAR(X, var) \ + do { \ + X(0, (var)); \ + X(1, (var)); \ + X(2, (var)); \ + X(3, (var)); \ + } while (0) #ifndef HUF_FORCE_DECOMPRESS_X2 @@ -283,10 +332,11 @@ typedef struct { BYTE nbBits; BYTE byte; } HUF_DEltX1; /* single-symbol decodi static U64 HUF_DEltX1_set4(BYTE symbol, BYTE nbBits) { U64 D4; if (MEM_isLittleEndian()) { - D4 = (symbol << 8) + nbBits; + D4 = (U64)((symbol << 8) + nbBits); } else { - D4 = symbol + (nbBits << 8); + D4 = (U64)(symbol + (nbBits << 8)); } + assert(D4 < (1U << 16)); D4 *= 0x0001000100010001ULL; return D4; } @@ -329,13 +379,7 @@ typedef struct { BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1]; } HUF_ReadDTableX1_Workspace; - -size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize) -{ - return HUF_readDTableX1_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - -size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags) { U32 tableLog = 0; U32 nbSymbols = 0; @@ -350,7 +394,7 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr DEBUG_STATIC_ASSERT(sizeof(DTableDesc) == sizeof(HUF_DTable)); /* ZSTD_memset(huffWeight, 0, sizeof(huffWeight)); */ /* is not necessary, even though some analyzer complain ... */ - iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), bmi2); + iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), flags); if (HUF_isError(iSize)) return iSize; @@ -377,9 +421,8 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr * rankStart[0] is not filled because there are no entries in the table for * weight 0. */ - { - int n; - int nextRankStart = 0; + { int n; + U32 nextRankStart = 0; int const unroll = 4; int const nLimit = (int)nbSymbols - unroll + 1; for (n=0; n<(int)tableLog+1; n++) { @@ -406,10 +449,9 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr * We can switch based on the length to a different inner loop which is * optimized for that particular case. */ - { - U32 w; - int symbol=wksp->rankVal[0]; - int rankStart=0; + { U32 w; + int symbol = wksp->rankVal[0]; + int rankStart = 0; for (w=1; wrankVal[w]; int const length = (1 << w) >> 1; @@ -483,15 +525,19 @@ HUF_decodeSymbolX1(BIT_DStream_t* Dstream, const HUF_DEltX1* dt, const U32 dtLog } #define HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) \ - *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog) + do { *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog); } while (0) -#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ - if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ - HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) +#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ + HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ + } while (0) -#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ - if (MEM_64bits()) \ - HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) +#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits()) \ + HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ + } while (0) HINT_INLINE size_t HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, const HUF_DEltX1* const dt, const U32 dtLog) @@ -519,7 +565,7 @@ HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, cons while (p < pEnd) HUF_DECODE_SYMBOLX1_0(p, bitDPtr); - return pEnd-pStart; + return (size_t)(pEnd-pStart); } FORCE_INLINE_TEMPLATE size_t @@ -529,7 +575,7 @@ HUF_decompress1X1_usingDTable_internal_body( const HUF_DTable* DTable) { BYTE* op = (BYTE*)dst; - BYTE* const oend = op + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(op, dstSize); const void* dtPtr = DTable + 1; const HUF_DEltX1* const dt = (const HUF_DEltX1*)dtPtr; BIT_DStream_t bitD; @@ -545,6 +591,10 @@ HUF_decompress1X1_usingDTable_internal_body( return dstSize; } +/* HUF_decompress4X1_usingDTable_internal_body(): + * Conditions : + * @dstSize >= 6 + */ FORCE_INLINE_TEMPLATE size_t HUF_decompress4X1_usingDTable_internal_body( void* dst, size_t dstSize, @@ -553,6 +603,7 @@ HUF_decompress4X1_usingDTable_internal_body( { /* Check */ if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ { const BYTE* const istart = (const BYTE*) cSrc; BYTE* const ostart = (BYTE*) dst; @@ -588,6 +639,7 @@ HUF_decompress4X1_usingDTable_internal_body( if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + assert(dstSize >= 6); /* validated above */ CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); @@ -650,52 +702,173 @@ size_t HUF_decompress4X1_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo } #endif -#if HUF_NEED_DEFAULT_FUNCTION static size_t HUF_decompress4X1_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, size_t cSrcSize, HUF_DTable const* DTable) { return HUF_decompress4X1_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); } -#endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 -HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; +HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; + +#endif + +static HUF_FAST_BMI2_ATTRS +void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) +{ + U64 bits[4]; + BYTE const* ip[4]; + BYTE* op[4]; + U16 const* const dtable = (U16 const*)args->dt; + BYTE* const oend = args->oend; + BYTE const* const ilowest = args->ilowest; + + /* Copy the arguments to local variables */ + ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); + ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); + ZSTD_memcpy(&op, &args->op, sizeof(op)); + + assert(MEM_isLittleEndian()); + assert(!MEM_32bits()); + + for (;;) { + BYTE* olimit; + int stream; + + /* Assert loop preconditions */ +#ifndef NDEBUG + for (stream = 0; stream < 4; ++stream) { + assert(op[stream] <= (stream == 3 ? oend : op[stream + 1])); + assert(ip[stream] >= ilowest); + } +#endif + /* Compute olimit */ + { + /* Each iteration produces 5 output symbols per stream */ + size_t const oiters = (size_t)(oend - op[3]) / 5; + /* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes + * per stream. + */ + size_t const iiters = (size_t)(ip[0] - ilowest) / 7; + /* We can safely run iters iterations before running bounds checks */ + size_t const iters = MIN(oiters, iiters); + size_t const symbols = iters * 5; + + /* We can simply check that op[3] < olimit, instead of checking all + * of our bounds, since we can't hit the other bounds until we've run + * iters iterations, which only happens when op[3] == olimit. + */ + olimit = op[3] + symbols; + + /* Exit fast decoding loop once we reach the end. */ + if (op[3] == olimit) + break; + + /* Exit the decoding loop if any input pointer has crossed the + * previous one. This indicates corruption, and a precondition + * to our loop is that ip[i] >= ip[0]. + */ + for (stream = 1; stream < 4; ++stream) { + if (ip[stream] < ip[stream - 1]) + goto _out; + } + } + +#ifndef NDEBUG + for (stream = 1; stream < 4; ++stream) { + assert(ip[stream] >= ip[stream - 1]); + } +#endif + +#define HUF_4X1_DECODE_SYMBOL(_stream, _symbol) \ + do { \ + int const index = (int)(bits[(_stream)] >> 53); \ + int const entry = (int)dtable[index]; \ + bits[(_stream)] <<= (entry & 0x3F); \ + op[(_stream)][(_symbol)] = (BYTE)((entry >> 8) & 0xFF); \ + } while (0) + +#define HUF_4X1_RELOAD_STREAM(_stream) \ + do { \ + int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ + int const nbBits = ctz & 7; \ + int const nbBytes = ctz >> 3; \ + op[(_stream)] += 5; \ + ip[(_stream)] -= nbBytes; \ + bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ + bits[(_stream)] <<= nbBits; \ + } while (0) + + /* Manually unroll the loop because compilers don't consistently + * unroll the inner loops, which destroys performance. + */ + do { + /* Decode 5 symbols in each of the 4 streams */ + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 1); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 2); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 3); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 4); + + /* Reload each of the 4 the bitstreams */ + HUF_4X_FOR_EACH_STREAM(HUF_4X1_RELOAD_STREAM); + } while (op[3] < olimit); + +#undef HUF_4X1_DECODE_SYMBOL +#undef HUF_4X1_RELOAD_STREAM + } -static HUF_ASM_X86_64_BMI2_ATTRS +_out: + + /* Save the final values of each of the state variables back to args. */ + ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); + ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); + ZSTD_memcpy(&args->op, &op, sizeof(op)); +} + +/* + * @returns @p dstSize on success (>= 6) + * 0 if the fallback implementation should be used + * An error if an error occurred + */ +static HUF_FAST_BMI2_ATTRS size_t -HUF_decompress4X1_usingDTable_internal_bmi2_asm( +HUF_decompress4X1_usingDTable_internal_fast( void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) + const HUF_DTable* DTable, + HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; - BYTE* const oend = (BYTE*)dst + dstSize; - HUF_DecompressAsmArgs args; - { - size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); - FORWARD_IF_ERROR(ret, "Failed to init asm args"); - if (ret != 0) - return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + BYTE const* const ilowest = (BYTE const*)cSrc; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); + HUF_DecompressFastArgs args; + { size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); + FORWARD_IF_ERROR(ret, "Failed to init fast loop args"); + if (ret == 0) + return 0; } - assert(args.ip[0] >= args.ilimit); - HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(&args); + assert(args.ip[0] >= args.ilowest); + loopFn(&args); - /* Our loop guarantees that ip[] >= ilimit and that we haven't + /* Our loop guarantees that ip[] >= ilowest and that we haven't * overwritten any op[]. */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bit streams one by one. */ - { - size_t const segmentSize = (dstSize+3) / 4; + { size_t const segmentSize = (dstSize+3) / 4; BYTE* segmentEnd = (BYTE*)dst; int i; for (i = 0; i < 4; ++i) { @@ -712,97 +885,59 @@ HUF_decompress4X1_usingDTable_internal_bmi2_asm( } /* decoded size */ + assert(dstSize != 0); return dstSize; } -#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ - -typedef size_t (*HUF_decompress_usingDTable_t)(void *dst, size_t dstSize, - const void *cSrc, - size_t cSrcSize, - const HUF_DTable *DTable); HUF_DGEN(HUF_decompress1X1_usingDTable_internal) static size_t HUF_decompress4X1_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) + size_t cSrcSize, HUF_DTable const* DTable, int flags) { + HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X1_usingDTable_internal_default; + HUF_DecompressFastLoopFn loopFn = HUF_decompress4X1_usingDTable_internal_fast_c_loop; + #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { + fallbackFn = HUF_decompress4X1_usingDTable_internal_bmi2; # if ZSTD_ENABLE_ASM_X86_64_BMI2 - return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -# else - return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; + } # endif + } else { + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -#else - (void)bmi2; #endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) - return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -#else - return HUF_decompress4X1_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; + } #endif -} - - -size_t HUF_decompress1X1_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 0) return ERROR(GENERIC); - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} -size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - const BYTE* ip = (const BYTE*) cSrc; - - size_t const hSize = HUF_readDTableX1_wksp(DCtx, cSrc, cSrcSize, workSpace, wkspSize); - if (HUF_isError(hSize)) return hSize; - if (hSize >= cSrcSize) return ERROR(srcSize_wrong); - ip += hSize; cSrcSize -= hSize; - - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); -} - - -size_t HUF_decompress4X1_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 0) return ERROR(GENERIC); - return HUF_decompress4X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); + if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { + size_t const ret = HUF_decompress4X1_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); + if (ret != 0) + return ret; + } + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -static size_t HUF_decompress4X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, +static size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; - size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); -} - -size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, 0); + return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } - #endif /* HUF_FORCE_DECOMPRESS_X2 */ @@ -985,7 +1120,7 @@ static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 targetLog, const U32 static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog, const sortedSymbol_t* sortedList, - const U32* rankStart, rankValCol_t *rankValOrigin, const U32 maxWeight, + const U32* rankStart, rankValCol_t* rankValOrigin, const U32 maxWeight, const U32 nbBitsBaseline) { U32* const rankVal = rankValOrigin[0]; @@ -1040,14 +1175,7 @@ typedef struct { size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_readDTableX2_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - -size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, - const void* src, size_t srcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { U32 tableLog, maxW, nbSymbols; DTableDesc dtd = HUF_getDTableDesc(DTable); @@ -1069,7 +1197,7 @@ size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); /* ZSTD_memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */ - iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), bmi2); + iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), flags); if (HUF_isError(iSize)) return iSize; /* check result */ @@ -1159,15 +1287,19 @@ HUF_decodeLastSymbolX2(void* op, BIT_DStream_t* DStream, const HUF_DEltX2* dt, c } #define HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) + do { ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); } while (0) -#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ - if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) +#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ + ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ + } while (0) -#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ - if (MEM_64bits()) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) +#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits()) \ + ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ + } while (0) HINT_INLINE size_t HUF_decodeStreamX2(BYTE* p, BIT_DStream_t* bitDPtr, BYTE* const pEnd, @@ -1227,7 +1359,7 @@ HUF_decompress1X2_usingDTable_internal_body( /* decode */ { BYTE* const ostart = (BYTE*) dst; - BYTE* const oend = ostart + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, dstSize); const void* const dtPtr = DTable+1; /* force compiler to not use strict-aliasing */ const HUF_DEltX2* const dt = (const HUF_DEltX2*)dtPtr; DTableDesc const dtd = HUF_getDTableDesc(DTable); @@ -1240,6 +1372,11 @@ HUF_decompress1X2_usingDTable_internal_body( /* decoded size */ return dstSize; } + +/* HUF_decompress4X2_usingDTable_internal_body(): + * Conditions: + * @dstSize >= 6 + */ FORCE_INLINE_TEMPLATE size_t HUF_decompress4X2_usingDTable_internal_body( void* dst, size_t dstSize, @@ -1247,6 +1384,7 @@ HUF_decompress4X2_usingDTable_internal_body( const HUF_DTable* DTable) { if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ { const BYTE* const istart = (const BYTE*) cSrc; BYTE* const ostart = (BYTE*) dst; @@ -1280,8 +1418,9 @@ HUF_decompress4X2_usingDTable_internal_body( DTableDesc const dtd = HUF_getDTableDesc(DTable); U32 const dtLog = dtd.tableLog; - if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ - if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ + if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + assert(dstSize >= 6 /* validated above */); CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); @@ -1366,44 +1505,191 @@ size_t HUF_decompress4X2_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo } #endif -#if HUF_NEED_DEFAULT_FUNCTION static size_t HUF_decompress4X2_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, size_t cSrcSize, HUF_DTable const* DTable) { return HUF_decompress4X2_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); } -#endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 -HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; +HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; + +#endif + +static HUF_FAST_BMI2_ATTRS +void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) +{ + U64 bits[4]; + BYTE const* ip[4]; + BYTE* op[4]; + BYTE* oend[4]; + HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt; + BYTE const* const ilowest = args->ilowest; + + /* Copy the arguments to local registers. */ + ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); + ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); + ZSTD_memcpy(&op, &args->op, sizeof(op)); + + oend[0] = op[1]; + oend[1] = op[2]; + oend[2] = op[3]; + oend[3] = args->oend; + + assert(MEM_isLittleEndian()); + assert(!MEM_32bits()); + + for (;;) { + BYTE* olimit; + int stream; + + /* Assert loop preconditions */ +#ifndef NDEBUG + for (stream = 0; stream < 4; ++stream) { + assert(op[stream] <= oend[stream]); + assert(ip[stream] >= ilowest); + } +#endif + /* Compute olimit */ + { + /* Each loop does 5 table lookups for each of the 4 streams. + * Each table lookup consumes up to 11 bits of input, and produces + * up to 2 bytes of output. + */ + /* We can consume up to 7 bytes of input per iteration per stream. + * We also know that each input pointer is >= ip[0]. So we can run + * iters loops before running out of input. + */ + size_t iters = (size_t)(ip[0] - ilowest) / 7; + /* Each iteration can produce up to 10 bytes of output per stream. + * Each output stream my advance at different rates. So take the + * minimum number of safe iterations among all the output streams. + */ + for (stream = 0; stream < 4; ++stream) { + size_t const oiters = (size_t)(oend[stream] - op[stream]) / 10; + iters = MIN(iters, oiters); + } + + /* Each iteration produces at least 5 output symbols. So until + * op[3] crosses olimit, we know we haven't executed iters + * iterations yet. This saves us maintaining an iters counter, + * at the expense of computing the remaining # of iterations + * more frequently. + */ + olimit = op[3] + (iters * 5); + + /* Exit the fast decoding loop once we reach the end. */ + if (op[3] == olimit) + break; + + /* Exit the decoding loop if any input pointer has crossed the + * previous one. This indicates corruption, and a precondition + * to our loop is that ip[i] >= ip[0]. + */ + for (stream = 1; stream < 4; ++stream) { + if (ip[stream] < ip[stream - 1]) + goto _out; + } + } + +#ifndef NDEBUG + for (stream = 1; stream < 4; ++stream) { + assert(ip[stream] >= ip[stream - 1]); + } +#endif -static HUF_ASM_X86_64_BMI2_ATTRS size_t -HUF_decompress4X2_usingDTable_internal_bmi2_asm( +#define HUF_4X2_DECODE_SYMBOL(_stream, _decode3) \ + do { \ + if ((_decode3) || (_stream) != 3) { \ + int const index = (int)(bits[(_stream)] >> 53); \ + HUF_DEltX2 const entry = dtable[index]; \ + MEM_write16(op[(_stream)], entry.sequence); \ + bits[(_stream)] <<= (entry.nbBits) & 0x3F; \ + op[(_stream)] += (entry.length); \ + } \ + } while (0) + +#define HUF_4X2_RELOAD_STREAM(_stream) \ + do { \ + HUF_4X2_DECODE_SYMBOL(3, 1); \ + { \ + int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ + int const nbBits = ctz & 7; \ + int const nbBytes = ctz >> 3; \ + ip[(_stream)] -= nbBytes; \ + bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ + bits[(_stream)] <<= nbBits; \ + } \ + } while (0) + + /* Manually unroll the loop because compilers don't consistently + * unroll the inner loops, which destroys performance. + */ + do { + /* Decode 5 symbols from each of the first 3 streams. + * The final stream will be decoded during the reload phase + * to reduce register pressure. + */ + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + + /* Decode one symbol from the final stream */ + HUF_4X2_DECODE_SYMBOL(3, 1); + + /* Decode 4 symbols from the final stream & reload bitstreams. + * The final stream is reloaded last, meaning that all 5 symbols + * are decoded from the final stream before it is reloaded. + */ + HUF_4X_FOR_EACH_STREAM(HUF_4X2_RELOAD_STREAM); + } while (op[3] < olimit); + } + +#undef HUF_4X2_DECODE_SYMBOL +#undef HUF_4X2_RELOAD_STREAM + +_out: + + /* Save the final values of each of the state variables back to args. */ + ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); + ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); + ZSTD_memcpy(&args->op, &op, sizeof(op)); +} + + +static HUF_FAST_BMI2_ATTRS size_t +HUF_decompress4X2_usingDTable_internal_fast( void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) { + const HUF_DTable* DTable, + HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; - BYTE* const oend = (BYTE*)dst + dstSize; - HUF_DecompressAsmArgs args; + const BYTE* const ilowest = (const BYTE*)cSrc; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); + HUF_DecompressFastArgs args; { - size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); + size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); FORWARD_IF_ERROR(ret, "Failed to init asm args"); - if (ret != 0) - return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (ret == 0) + return 0; } - assert(args.ip[0] >= args.ilimit); - HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(&args); + assert(args.ip[0] >= args.ilowest); + loopFn(&args); /* note : op4 already verified within main loop */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bitStreams one by one */ { @@ -1426,91 +1712,72 @@ HUF_decompress4X2_usingDTable_internal_bmi2_asm( /* decoded size */ return dstSize; } -#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ static size_t HUF_decompress4X2_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) + size_t cSrcSize, HUF_DTable const* DTable, int flags) { + HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X2_usingDTable_internal_default; + HUF_DecompressFastLoopFn loopFn = HUF_decompress4X2_usingDTable_internal_fast_c_loop; + #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { + fallbackFn = HUF_decompress4X2_usingDTable_internal_bmi2; # if ZSTD_ENABLE_ASM_X86_64_BMI2 - return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -# else - return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; + } # endif + } else { + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -#else - (void)bmi2; #endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) - return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -#else - return HUF_decompress4X2_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; + } #endif + + if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { + size_t const ret = HUF_decompress4X2_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); + if (ret != 0) + return ret; + } + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } HUF_DGEN(HUF_decompress1X2_usingDTable_internal) -size_t HUF_decompress1X2_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 1) return ERROR(GENERIC); - return HUF_decompress1X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} - size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; size_t const hSize = HUF_readDTableX2_wksp(DCtx, cSrc, cSrcSize, - workSpace, wkspSize); + workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); + return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, flags); } - -size_t HUF_decompress4X2_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 1) return ERROR(GENERIC); - return HUF_decompress4X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} - -static size_t HUF_decompress4X2_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, +static size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; size_t hSize = HUF_readDTableX2_wksp(dctx, cSrc, cSrcSize, - workSpace, wkspSize); + workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); + return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } -size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - - #endif /* HUF_FORCE_DECOMPRESS_X1 */ @@ -1518,44 +1785,6 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, /* Universal decompression selectors */ /* ***********************************/ -size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc const dtd = HUF_getDTableDesc(DTable); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)dtd; - assert(dtd.tableType == 0); - return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)dtd; - assert(dtd.tableType == 1); - return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#else - return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : - HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#endif -} - -size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc const dtd = HUF_getDTableDesc(DTable); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)dtd; - assert(dtd.tableType == 0); - return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)dtd; - assert(dtd.tableType == 1); - return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#else - return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : - HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#endif -} - #if !defined(HUF_FORCE_DECOMPRESS_X1) && !defined(HUF_FORCE_DECOMPRESS_X2) typedef struct { U32 tableTime; U32 decode256Time; } algo_time_t; @@ -1610,36 +1839,9 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize) #endif } - -size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, - size_t dstSize, const void* cSrc, - size_t cSrcSize, void* workSpace, - size_t wkspSize) -{ - /* validation checks */ - if (dstSize == 0) return ERROR(dstSize_tooSmall); - if (cSrcSize == 0) return ERROR(corruption_detected); - - { U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)algoNb; - assert(algoNb == 0); - return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)algoNb; - assert(algoNb == 1); - return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#else - return algoNb ? HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize): - HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#endif - } -} - size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) + void* workSpace, size_t wkspSize, int flags) { /* validation checks */ if (dstSize == 0) return ERROR(dstSize_tooSmall); @@ -1652,71 +1854,71 @@ size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, (void)algoNb; assert(algoNb == 0); return HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)algoNb; assert(algoNb == 1); return HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #else return algoNb ? HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize): + cSrcSize, workSpace, wkspSize, flags): HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #endif } } -size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) +size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) { DTableDesc const dtd = HUF_getDTableDesc(DTable); #if defined(HUF_FORCE_DECOMPRESS_X1) (void)dtd; assert(dtd.tableType == 0); - return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)dtd; assert(dtd.tableType == 1); - return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #else - return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : - HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : + HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #endif } #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; - size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); + return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } #endif -size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) +size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) { DTableDesc const dtd = HUF_getDTableDesc(DTable); #if defined(HUF_FORCE_DECOMPRESS_X1) (void)dtd; assert(dtd.tableType == 0); - return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)dtd; assert(dtd.tableType == 1); - return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #else - return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : - HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : + HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #endif } -size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) { /* validation checks */ if (dstSize == 0) return ERROR(dstSize_tooSmall); @@ -1726,15 +1928,14 @@ size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t ds #if defined(HUF_FORCE_DECOMPRESS_X1) (void)algoNb; assert(algoNb == 0); - return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)algoNb; assert(algoNb == 1); - return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #else - return algoNb ? HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2) : - HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return algoNb ? HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags) : + HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #endif } } - diff --git a/lib/zstd/decompress/zstd_ddict.c b/lib/zstd/decompress/zstd_ddict.c index dbbc7919de53..30ef65e1ab5c 100644 --- a/lib/zstd/decompress/zstd_ddict.c +++ b/lib/zstd/decompress/zstd_ddict.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,12 +15,12 @@ /*-******************************************************* * Dependencies *********************************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ #include "../common/cpu.h" /* bmi2 */ #include "../common/mem.h" /* low level memory routines */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "zstd_decompress_internal.h" #include "zstd_ddict.h" @@ -131,7 +132,7 @@ static size_t ZSTD_initDDict_internal(ZSTD_DDict* ddict, ZSTD_memcpy(internalBuffer, dict, dictSize); } ddict->dictSize = dictSize; - ddict->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + ddict->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ /* parse dictionary content */ FORWARD_IF_ERROR( ZSTD_loadEntropy_intoDDict(ddict, dictContentType) , ""); @@ -237,5 +238,5 @@ size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict) unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict) { if (ddict==NULL) return 0; - return ZSTD_getDictID_fromDict(ddict->dictContent, ddict->dictSize); + return ddict->dictID; } diff --git a/lib/zstd/decompress/zstd_ddict.h b/lib/zstd/decompress/zstd_ddict.h index 8c1a79d666f8..de459a0dacd1 100644 --- a/lib/zstd/decompress/zstd_ddict.h +++ b/lib/zstd/decompress/zstd_ddict.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/decompress/zstd_decompress.c b/lib/zstd/decompress/zstd_decompress.c index 6b3177c94711..da8b4cf116e3 100644 --- a/lib/zstd/decompress/zstd_decompress.c +++ b/lib/zstd/decompress/zstd_decompress.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -53,13 +54,15 @@ * Dependencies *********************************************************/ #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ +#include "../common/error_private.h" +#include "../common/zstd_internal.h" /* blockProperties_t */ #include "../common/mem.h" /* low level memory routines */ +#include "../common/bits.h" /* ZSTD_highbit32 */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include /* xxh64_reset, xxh64_update, xxh64_digest, XXH64 */ -#include "../common/zstd_internal.h" /* blockProperties_t */ #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ #include "zstd_decompress_block.h" /* ZSTD_decompressBlock_internal */ @@ -72,11 +75,11 @@ *************************************/ #define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4 -#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. - * Currently, that means a 0.75 load factor. - * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded - * the load factor of the ddict hash set. - */ +#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. + * Currently, that means a 0.75 load factor. + * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded + * the load factor of the ddict hash set. + */ #define DDICT_HASHSET_TABLE_BASE_SIZE 64 #define DDICT_HASHSET_RESIZE_FACTOR 2 @@ -237,6 +240,8 @@ static void ZSTD_DCtx_resetParameters(ZSTD_DCtx* dctx) dctx->outBufferMode = ZSTD_bm_buffered; dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum; dctx->refMultipleDDicts = ZSTD_rmd_refSingleDDict; + dctx->disableHufAsm = 0; + dctx->maxBlockSizeParam = 0; } static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) @@ -253,6 +258,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) dctx->streamStage = zdss_init; dctx->noForwardProgress = 0; dctx->oversizedDuration = 0; + dctx->isFrameDecompression = 1; #if DYNAMIC_BMI2 dctx->bmi2 = ZSTD_cpuSupportsBmi2(); #endif @@ -421,16 +427,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize) * note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless * @return : 0, `zfhPtr` is correctly filled, * >0, `srcSize` is too small, value is wanted `srcSize` amount, - * or an error code, which can be tested using ZSTD_isError() */ -size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format) +** or an error code, which can be tested using ZSTD_isError() */ +size_t ZSTD_getFrameHeader_advanced(ZSTD_FrameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format) { const BYTE* ip = (const BYTE*)src; size_t const minInputSize = ZSTD_startingInputLength(format); - ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */ - if (srcSize < minInputSize) return minInputSize; - RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter"); + DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize); + + if (srcSize > 0) { + /* note : technically could be considered an assert(), since it's an invalid entry */ + RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0"); + } + if (srcSize < minInputSize) { + if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) { + /* when receiving less than @minInputSize bytes, + * control these bytes at least correspond to a supported magic number + * in order to error out early if they don't. + **/ + size_t const toCopy = MIN(4, srcSize); + unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER); + assert(src != NULL); + ZSTD_memcpy(hbuf, src, toCopy); + if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) { + /* not a zstd frame : let's check if it's a skippable frame */ + MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START); + ZSTD_memcpy(hbuf, src, toCopy); + if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) { + RETURN_ERROR(prefix_unknown, + "first bytes don't correspond to any supported magic number"); + } } } + return minInputSize; + } + ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */ if ( (format != ZSTD_f_zstd1_magicless) && (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) { if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { @@ -438,8 +468,10 @@ size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, s if (srcSize < ZSTD_SKIPPABLEHEADERSIZE) return ZSTD_SKIPPABLEHEADERSIZE; /* magic number + frame length */ ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); - zfhPtr->frameContentSize = MEM_readLE32((const char *)src + ZSTD_FRAMEIDSIZE); zfhPtr->frameType = ZSTD_skippableFrame; + zfhPtr->dictID = MEM_readLE32(src) - ZSTD_MAGIC_SKIPPABLE_START; + zfhPtr->headerSize = ZSTD_SKIPPABLEHEADERSIZE; + zfhPtr->frameContentSize = MEM_readLE32((const char *)src + ZSTD_FRAMEIDSIZE); return 0; } RETURN_ERROR(prefix_unknown, ""); @@ -508,7 +540,7 @@ size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, s * @return : 0, `zfhPtr` is correctly filled, * >0, `srcSize` is too small, value is wanted `srcSize` amount, * or an error code, which can be tested using ZSTD_isError() */ -size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize) +size_t ZSTD_getFrameHeader(ZSTD_FrameHeader* zfhPtr, const void* src, size_t srcSize) { return ZSTD_getFrameHeader_advanced(zfhPtr, src, srcSize, ZSTD_f_zstd1); } @@ -520,7 +552,7 @@ size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t src * - ZSTD_CONTENTSIZE_ERROR if an error occurred (e.g. invalid magic number, srcSize too small) */ unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize) { - { ZSTD_frameHeader zfh; + { ZSTD_FrameHeader zfh; if (ZSTD_getFrameHeader(&zfh, src, srcSize) != 0) return ZSTD_CONTENTSIZE_ERROR; if (zfh.frameType == ZSTD_skippableFrame) { @@ -540,61 +572,62 @@ static size_t readSkippableFrameSize(void const* src, size_t srcSize) sizeU32 = MEM_readLE32((BYTE const*)src + ZSTD_FRAMEIDSIZE); RETURN_ERROR_IF((U32)(sizeU32 + ZSTD_SKIPPABLEHEADERSIZE) < sizeU32, frameParameter_unsupported, ""); - { - size_t const skippableSize = skippableHeaderSize + sizeU32; + { size_t const skippableSize = skippableHeaderSize + sizeU32; RETURN_ERROR_IF(skippableSize > srcSize, srcSize_wrong, ""); return skippableSize; } } /*! ZSTD_readSkippableFrame() : - * Retrieves a zstd skippable frame containing data given by src, and writes it to dst buffer. + * Retrieves content of a skippable frame, and writes it to dst buffer. * * The parameter magicVariant will receive the magicVariant that was supplied when the frame was written, * i.e. magicNumber - ZSTD_MAGIC_SKIPPABLE_START. This can be NULL if the caller is not interested * in the magicVariant. * - * Returns an error if destination buffer is not large enough, or if the frame is not skippable. + * Returns an error if destination buffer is not large enough, or if this is not a valid skippable frame. * * @return : number of bytes written or a ZSTD error. */ -ZSTDLIB_API size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, unsigned* magicVariant, - const void* src, size_t srcSize) +size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, + unsigned* magicVariant, /* optional, can be NULL */ + const void* src, size_t srcSize) { - U32 const magicNumber = MEM_readLE32(src); - size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); - size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; - - /* check input validity */ - RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, ""); - RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); - RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); + RETURN_ERROR_IF(srcSize < ZSTD_SKIPPABLEHEADERSIZE, srcSize_wrong, ""); - /* deliver payload */ - if (skippableContentSize > 0 && dst != NULL) - ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); - if (magicVariant != NULL) - *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; - return skippableContentSize; + { U32 const magicNumber = MEM_readLE32(src); + size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); + size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; + + /* check input validity */ + RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, ""); + RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); + RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); + + /* deliver payload */ + if (skippableContentSize > 0 && dst != NULL) + ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); + if (magicVariant != NULL) + *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; + return skippableContentSize; + } } /* ZSTD_findDecompressedSize() : - * compatible with legacy mode * `srcSize` must be the exact length of some number of ZSTD compressed and/or * skippable frames - * @return : decompressed size of the frames contained */ + * note: compatible with legacy mode + * @return : decompressed size of the frames contained */ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) { - unsigned long long totalDstSize = 0; + U64 totalDstSize = 0; while (srcSize >= ZSTD_startingInputLength(ZSTD_f_zstd1)) { U32 const magicNumber = MEM_readLE32(src); if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { size_t const skippableSize = readSkippableFrameSize(src, srcSize); - if (ZSTD_isError(skippableSize)) { - return ZSTD_CONTENTSIZE_ERROR; - } + if (ZSTD_isError(skippableSize)) return ZSTD_CONTENTSIZE_ERROR; assert(skippableSize <= srcSize); src = (const BYTE *)src + skippableSize; @@ -602,17 +635,17 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) continue; } - { unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize); - if (ret >= ZSTD_CONTENTSIZE_ERROR) return ret; + { unsigned long long const fcs = ZSTD_getFrameContentSize(src, srcSize); + if (fcs >= ZSTD_CONTENTSIZE_ERROR) return fcs; - /* check for overflow */ - if (totalDstSize + ret < totalDstSize) return ZSTD_CONTENTSIZE_ERROR; - totalDstSize += ret; + if (U64_MAX - totalDstSize < fcs) + return ZSTD_CONTENTSIZE_ERROR; /* check for overflow */ + totalDstSize += fcs; } + /* skip to next frame */ { size_t const frameSrcSize = ZSTD_findFrameCompressedSize(src, srcSize); - if (ZSTD_isError(frameSrcSize)) { - return ZSTD_CONTENTSIZE_ERROR; - } + if (ZSTD_isError(frameSrcSize)) return ZSTD_CONTENTSIZE_ERROR; + assert(frameSrcSize <= srcSize); src = (const BYTE *)src + frameSrcSize; srcSize -= frameSrcSize; @@ -676,13 +709,13 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret) return frameSizeInfo; } -static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize) +static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format) { ZSTD_frameSizeInfo frameSizeInfo; ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo)); - if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE) + if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE) && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize); assert(ZSTD_isError(frameSizeInfo.compressedSize) || @@ -693,10 +726,10 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize const BYTE* const ipstart = ip; size_t remainingSize = srcSize; size_t nbBlocks = 0; - ZSTD_frameHeader zfh; + ZSTD_FrameHeader zfh; /* Extract Frame Header */ - { size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize); + { size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format); if (ZSTD_isError(ret)) return ZSTD_errorFrameSizeInfo(ret); if (ret > 0) @@ -730,28 +763,31 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize ip += 4; } + frameSizeInfo.nbBlocks = nbBlocks; frameSizeInfo.compressedSize = (size_t)(ip - ipstart); frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) ? zfh.frameContentSize - : nbBlocks * zfh.blockSizeMax; + : (unsigned long long)nbBlocks * zfh.blockSizeMax; return frameSizeInfo; } } +static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format); + return frameSizeInfo.compressedSize; +} + /* ZSTD_findFrameCompressedSize() : - * compatible with legacy mode - * `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame - * `srcSize` must be at least as large as the frame contained - * @return : the compressed size of the frame starting at `src` */ + * See docs in zstd.h + * Note: compatible with legacy mode */ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); - return frameSizeInfo.compressedSize; + return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1); } /* ZSTD_decompressBound() : * compatible with legacy mode - * `src` must point to the start of a ZSTD frame or a skippeable frame + * `src` must point to the start of a ZSTD frame or a skippable frame * `srcSize` must be at least as large as the frame contained * @return : the maximum decompressed size of the compressed source */ @@ -760,7 +796,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) unsigned long long bound = 0; /* Iterate over each frame */ while (srcSize > 0) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); size_t const compressedSize = frameSizeInfo.compressedSize; unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) @@ -773,6 +809,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) return bound; } +size_t ZSTD_decompressionMargin(void const* src, size_t srcSize) +{ + size_t margin = 0; + unsigned maxBlockSize = 0; + + /* Iterate over each frame */ + while (srcSize > 0) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); + size_t const compressedSize = frameSizeInfo.compressedSize; + unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; + ZSTD_FrameHeader zfh; + + FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), ""); + if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) + return ERROR(corruption_detected); + + if (zfh.frameType == ZSTD_frame) { + /* Add the frame header to our margin */ + margin += zfh.headerSize; + /* Add the checksum to our margin */ + margin += zfh.checksumFlag ? 4 : 0; + /* Add 3 bytes per block */ + margin += 3 * frameSizeInfo.nbBlocks; + + /* Compute the max block size */ + maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax); + } else { + assert(zfh.frameType == ZSTD_skippableFrame); + /* Add the entire skippable frame size to our margin. */ + margin += compressedSize; + } + + assert(srcSize >= compressedSize); + src = (const BYTE*)src + compressedSize; + srcSize -= compressedSize; + } + + /* Add the max block size back to the margin. */ + margin += maxBlockSize; + + return margin; +} /*-************************************************************* * Frame decoding @@ -815,7 +893,7 @@ static size_t ZSTD_setRleBlock(void* dst, size_t dstCapacity, return regenSize; } -static void ZSTD_DCtx_trace_end(ZSTD_DCtx const* dctx, U64 uncompressedSize, U64 compressedSize, unsigned streaming) +static void ZSTD_DCtx_trace_end(ZSTD_DCtx const* dctx, U64 uncompressedSize, U64 compressedSize, int streaming) { (void)dctx; (void)uncompressedSize; @@ -856,6 +934,10 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, ip += frameHeaderSize; remainingSrcSize -= frameHeaderSize; } + /* Shrink the blockSizeMax if enabled */ + if (dctx->maxBlockSizeParam != 0) + dctx->fParams.blockSizeMax = MIN(dctx->fParams.blockSizeMax, (unsigned)dctx->maxBlockSizeParam); + /* Loop on each block */ while (1) { BYTE* oBlockEnd = oend; @@ -888,7 +970,8 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, switch(blockProperties.blockType) { case bt_compressed: - decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming); + assert(dctx->isFrameDecompression == 1); + decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, not_streaming); break; case bt_raw : /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */ @@ -901,12 +984,14 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, default: RETURN_ERROR(corruption_detected, "invalid block type"); } - - if (ZSTD_isError(decodedSize)) return decodedSize; - if (dctx->validateChecksum) + FORWARD_IF_ERROR(decodedSize, "Block decompression failure"); + DEBUGLOG(5, "Decompressed block of dSize = %u", (unsigned)decodedSize); + if (dctx->validateChecksum) { xxh64_update(&dctx->xxhState, op, decodedSize); - if (decodedSize != 0) + } + if (decodedSize) /* support dst = NULL,0 */ { op += decodedSize; + } assert(ip != NULL); ip += cBlockSize; remainingSrcSize -= cBlockSize; @@ -930,12 +1015,15 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, } ZSTD_DCtx_trace_end(dctx, (U64)(op-ostart), (U64)(ip-istart), /* streaming */ 0); /* Allow caller to get size read */ + DEBUGLOG(4, "ZSTD_decompressFrame: decompressed frame of size %i, consuming %i bytes of input", (int)(op-ostart), (int)(ip - (const BYTE*)*srcPtr)); *srcPtr = ip; *srcSizePtr = remainingSrcSize; return (size_t)(op-ostart); } -static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, const void* dict, size_t dictSize, @@ -955,17 +1043,18 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, while (srcSize >= ZSTD_startingInputLength(dctx->format)) { - { U32 const magicNumber = MEM_readLE32(src); - DEBUGLOG(4, "reading magic number %08X (expecting %08X)", - (unsigned)magicNumber, ZSTD_MAGICNUMBER); + if (dctx->format == ZSTD_f_zstd1 && srcSize >= 4) { + U32 const magicNumber = MEM_readLE32(src); + DEBUGLOG(5, "reading magic number %08X", (unsigned)magicNumber); if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { + /* skippable frame detected : skip it */ size_t const skippableSize = readSkippableFrameSize(src, srcSize); - FORWARD_IF_ERROR(skippableSize, "readSkippableFrameSize failed"); + FORWARD_IF_ERROR(skippableSize, "invalid skippable frame"); assert(skippableSize <= srcSize); src = (const BYTE *)src + skippableSize; srcSize -= skippableSize; - continue; + continue; /* check next frame */ } } if (ddict) { @@ -1061,8 +1150,8 @@ size_t ZSTD_decompress(void* dst, size_t dstCapacity, const void* src, size_t sr size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx) { return dctx->expected; } /* - * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, - * we allow taking a partial block as the input. Currently only raw uncompressed blocks can + * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, we + * allow taking a partial block as the input. Currently only raw uncompressed blocks can * be streamed. * * For blocks that can be streamed, this allows us to reduce the latency until we produce @@ -1181,7 +1270,8 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c { case bt_compressed: DEBUGLOG(5, "ZSTD_decompressContinue: case bt_compressed"); - rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 1, is_streaming); + assert(dctx->isFrameDecompression == 1); + rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, is_streaming); dctx->expected = 0; /* Streaming not supported */ break; case bt_raw : @@ -1250,6 +1340,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c case ZSTDds_decodeSkippableHeader: assert(src != NULL); assert(srcSize <= ZSTD_SKIPPABLEHEADERSIZE); + assert(dctx->format != ZSTD_f_zstd1_magicless); ZSTD_memcpy(dctx->headerBuffer + (ZSTD_SKIPPABLEHEADERSIZE - srcSize), src, srcSize); /* complete skippable header */ dctx->expected = MEM_readLE32(dctx->headerBuffer + ZSTD_FRAMEIDSIZE); /* note : dctx->expected can grow seriously large, beyond local buffer size */ dctx->stage = ZSTDds_skipFrame; @@ -1262,7 +1353,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c default: assert(0); /* impossible */ - RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ + RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ } } @@ -1303,11 +1394,11 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, /* in minimal huffman, we always use X1 variants */ size_t const hSize = HUF_readDTableX1_wksp(entropy->hufTable, dictPtr, dictEnd - dictPtr, - workspace, workspaceSize); + workspace, workspaceSize, /* flags */ 0); #else size_t const hSize = HUF_readDTableX2_wksp(entropy->hufTable, dictPtr, (size_t)(dictEnd - dictPtr), - workspace, workspaceSize); + workspace, workspaceSize, /* flags */ 0); #endif RETURN_ERROR_IF(HUF_isError(hSize), dictionary_corrupted, ""); dictPtr += hSize; @@ -1403,10 +1494,11 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) dctx->prefixStart = NULL; dctx->virtualStart = NULL; dctx->dictEnd = NULL; - dctx->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + dctx->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ dctx->litEntropy = dctx->fseEntropy = 0; dctx->dictID = 0; dctx->bType = bt_reserved; + dctx->isFrameDecompression = 1; ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue)); ZSTD_memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue)); /* initial repcodes */ dctx->LLTptr = dctx->entropy.LLTable; @@ -1465,7 +1557,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) * This could for one of the following reasons : * - The frame does not require a dictionary (most common case). * - The frame was built with dictID intentionally removed. - * Needed dictionary is a hidden information. + * Needed dictionary is a hidden piece of information. * Note : this use case also happens when using a non-conformant dictionary. * - `srcSize` is too small, and as a result, frame header could not be decoded. * Note : possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`. @@ -1474,7 +1566,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) * ZSTD_getFrameHeader(), which will provide a more precise error code. */ unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize) { - ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0 }; + ZSTD_FrameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0, 0, 0 }; size_t const hError = ZSTD_getFrameHeader(&zfp, src, srcSize); if (ZSTD_isError(hError)) return 0; return zfp.dictID; @@ -1581,7 +1673,9 @@ size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t di size_t ZSTD_initDStream(ZSTD_DStream* zds) { DEBUGLOG(4, "ZSTD_initDStream"); - return ZSTD_initDStream_usingDDict(zds, NULL); + FORWARD_IF_ERROR(ZSTD_DCtx_reset(zds, ZSTD_reset_session_only), ""); + FORWARD_IF_ERROR(ZSTD_DCtx_refDDict(zds, NULL), ""); + return ZSTD_startingInputLength(zds->format); } /* ZSTD_initDStream_usingDDict() : @@ -1589,6 +1683,7 @@ size_t ZSTD_initDStream(ZSTD_DStream* zds) * this function cannot fail */ size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) { + DEBUGLOG(4, "ZSTD_initDStream_usingDDict"); FORWARD_IF_ERROR( ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only) , ""); FORWARD_IF_ERROR( ZSTD_DCtx_refDDict(dctx, ddict) , ""); return ZSTD_startingInputLength(dctx->format); @@ -1599,6 +1694,7 @@ size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) * this function cannot fail */ size_t ZSTD_resetDStream(ZSTD_DStream* dctx) { + DEBUGLOG(4, "ZSTD_resetDStream"); FORWARD_IF_ERROR(ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only), ""); return ZSTD_startingInputLength(dctx->format); } @@ -1670,6 +1766,15 @@ ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam) bounds.lowerBound = (int)ZSTD_rmd_refSingleDDict; bounds.upperBound = (int)ZSTD_rmd_refMultipleDDicts; return bounds; + case ZSTD_d_disableHuffmanAssembly: + bounds.lowerBound = 0; + bounds.upperBound = 1; + return bounds; + case ZSTD_d_maxBlockSize: + bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; + bounds.upperBound = ZSTD_BLOCKSIZE_MAX; + return bounds; + default:; } bounds.error = ERROR(parameter_unsupported); @@ -1710,6 +1815,12 @@ size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParameter param, int* value case ZSTD_d_refMultipleDDicts: *value = (int)dctx->refMultipleDDicts; return 0; + case ZSTD_d_disableHuffmanAssembly: + *value = (int)dctx->disableHufAsm; + return 0; + case ZSTD_d_maxBlockSize: + *value = dctx->maxBlockSizeParam; + return 0; default:; } RETURN_ERROR(parameter_unsupported, ""); @@ -1743,6 +1854,14 @@ size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value } dctx->refMultipleDDicts = (ZSTD_refMultipleDDicts_e)value; return 0; + case ZSTD_d_disableHuffmanAssembly: + CHECK_DBOUNDS(ZSTD_d_disableHuffmanAssembly, value); + dctx->disableHufAsm = value != 0; + return 0; + case ZSTD_d_maxBlockSize: + if (value != 0) CHECK_DBOUNDS(ZSTD_d_maxBlockSize, value); + dctx->maxBlockSizeParam = value; + return 0; default:; } RETURN_ERROR(parameter_unsupported, ""); @@ -1754,6 +1873,7 @@ size_t ZSTD_DCtx_reset(ZSTD_DCtx* dctx, ZSTD_ResetDirective reset) || (reset == ZSTD_reset_session_and_parameters) ) { dctx->streamStage = zdss_init; dctx->noForwardProgress = 0; + dctx->isFrameDecompression = 1; } if ( (reset == ZSTD_reset_parameters) || (reset == ZSTD_reset_session_and_parameters) ) { @@ -1770,11 +1890,17 @@ size_t ZSTD_sizeof_DStream(const ZSTD_DStream* dctx) return ZSTD_sizeof_DCtx(dctx); } -size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) +static size_t ZSTD_decodingBufferSize_internal(unsigned long long windowSize, unsigned long long frameContentSize, size_t blockSizeMax) { - size_t const blockSize = (size_t) MIN(windowSize, ZSTD_BLOCKSIZE_MAX); - /* space is needed to store the litbuffer after the output of a given block without stomping the extDict of a previous run, as well as to cover both windows against wildcopy*/ - unsigned long long const neededRBSize = windowSize + blockSize + ZSTD_BLOCKSIZE_MAX + (WILDCOPY_OVERLENGTH * 2); + size_t const blockSize = MIN((size_t)MIN(windowSize, ZSTD_BLOCKSIZE_MAX), blockSizeMax); + /* We need blockSize + WILDCOPY_OVERLENGTH worth of buffer so that if a block + * ends at windowSize + WILDCOPY_OVERLENGTH + 1 bytes, we can start writing + * the block at the beginning of the output buffer, and maintain a full window. + * + * We need another blockSize worth of buffer so that we can store split + * literals at the end of the block without overwriting the extDict window. + */ + unsigned long long const neededRBSize = windowSize + (blockSize * 2) + (WILDCOPY_OVERLENGTH * 2); unsigned long long const neededSize = MIN(frameContentSize, neededRBSize); size_t const minRBSize = (size_t) neededSize; RETURN_ERROR_IF((unsigned long long)minRBSize != neededSize, @@ -1782,6 +1908,11 @@ size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long return minRBSize; } +size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) +{ + return ZSTD_decodingBufferSize_internal(windowSize, frameContentSize, ZSTD_BLOCKSIZE_MAX); +} + size_t ZSTD_estimateDStreamSize(size_t windowSize) { size_t const blockSize = MIN(windowSize, ZSTD_BLOCKSIZE_MAX); @@ -1793,7 +1924,7 @@ size_t ZSTD_estimateDStreamSize(size_t windowSize) size_t ZSTD_estimateDStreamSize_fromFrame(const void* src, size_t srcSize) { U32 const windowSizeMax = 1U << ZSTD_WINDOWLOG_MAX; /* note : should be user-selectable, but requires an additional parameter (or a dctx) */ - ZSTD_frameHeader zfh; + ZSTD_FrameHeader zfh; size_t const err = ZSTD_getFrameHeader(&zfh, src, srcSize); if (ZSTD_isError(err)) return err; RETURN_ERROR_IF(err>0, srcSize_wrong, ""); @@ -1888,6 +2019,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB U32 someMoreWork = 1; DEBUGLOG(5, "ZSTD_decompressStream"); + assert(zds != NULL); RETURN_ERROR_IF( input->pos > input->size, srcSize_wrong, @@ -1918,7 +2050,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->refMultipleDDicts && zds->ddictSet) { ZSTD_DCtx_selectFrameDDict(zds); } - DEBUGLOG(5, "header size : %u", (U32)hSize); if (ZSTD_isError(hSize)) { return hSize; /* error */ } @@ -1932,6 +2063,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->lhSize += remainingInput; } input->pos = input->size; + /* check first few bytes */ + FORWARD_IF_ERROR( + ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format), + "First few bytes detected incorrect" ); + /* return hint input size */ return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */ } assert(ip != NULL); @@ -1943,14 +2079,15 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN && zds->fParams.frameType != ZSTD_skippableFrame && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) { - size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart)); + size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format); if (cSize <= (size_t)(iend-istart)) { /* shortcut : using single-pass mode */ size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds)); if (ZSTD_isError(decompressedSize)) return decompressedSize; - DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()") + DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()"); + assert(istart != NULL); ip = istart + cSize; - op += decompressedSize; + op = op ? op + decompressedSize : op; /* can occur if frameContentSize = 0 (empty frame) */ zds->expected = 0; zds->streamStage = zdss_init; someMoreWork = 0; @@ -1969,7 +2106,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB DEBUGLOG(4, "Consume header"); FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)), ""); - if ((MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ + if (zds->format == ZSTD_f_zstd1 + && (MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ zds->expected = MEM_readLE32(zds->headerBuffer + ZSTD_FRAMEIDSIZE); zds->stage = ZSTDds_skipFrame; } else { @@ -1985,11 +2123,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN); RETURN_ERROR_IF(zds->fParams.windowSize > zds->maxWindowSize, frameParameter_windowTooLarge, ""); + if (zds->maxBlockSizeParam != 0) + zds->fParams.blockSizeMax = MIN(zds->fParams.blockSizeMax, (unsigned)zds->maxBlockSizeParam); /* Adapt buffer sizes to frame header instructions */ { size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */); size_t const neededOutBuffSize = zds->outBufferMode == ZSTD_bm_buffered - ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize) + ? ZSTD_decodingBufferSize_internal(zds->fParams.windowSize, zds->fParams.frameContentSize, zds->fParams.blockSizeMax) : 0; ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize); @@ -2034,6 +2174,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB } if ((size_t)(iend-ip) >= neededInSize) { /* decode directly from src */ FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, ip, neededInSize), ""); + assert(ip != NULL); ip += neededInSize; /* Function modifies the stage so we must break */ break; @@ -2048,7 +2189,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB int const isSkipFrame = ZSTD_isSkipFrame(zds); size_t loadedSize; /* At this point we shouldn't be decompressing a block that we can stream. */ - assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, iend - ip)); + assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, (size_t)(iend - ip))); if (isSkipFrame) { loadedSize = MIN(toLoad, (size_t)(iend-ip)); } else { @@ -2057,8 +2198,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB "should never happen"); loadedSize = ZSTD_limitCopy(zds->inBuff + zds->inPos, toLoad, ip, (size_t)(iend-ip)); } - ip += loadedSize; - zds->inPos += loadedSize; + if (loadedSize != 0) { + /* ip may be NULL */ + ip += loadedSize; + zds->inPos += loadedSize; + } if (loadedSize < toLoad) { someMoreWork = 0; break; } /* not enough input, wait for more */ /* decode loaded input */ @@ -2068,14 +2212,17 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB break; } case zdss_flush: - { size_t const toFlushSize = zds->outEnd - zds->outStart; + { + size_t const toFlushSize = zds->outEnd - zds->outStart; size_t const flushedSize = ZSTD_limitCopy(op, (size_t)(oend-op), zds->outBuff + zds->outStart, toFlushSize); - op += flushedSize; + + op = op ? op + flushedSize : op; + zds->outStart += flushedSize; if (flushedSize == toFlushSize) { /* flush completed */ zds->streamStage = zdss_read; if ( (zds->outBuffSize < zds->fParams.frameContentSize) - && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { + && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)", (int)(zds->outBuffSize - zds->outStart), (U32)zds->fParams.blockSizeMax); @@ -2089,7 +2236,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB default: assert(0); /* impossible */ - RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ + RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ } } /* result */ @@ -2102,8 +2249,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if ((ip==istart) && (op==ostart)) { /* no forward progress */ zds->noForwardProgress ++; if (zds->noForwardProgress >= ZSTD_NO_FORWARD_PROGRESS_MAX) { - RETURN_ERROR_IF(op==oend, dstSize_tooSmall, ""); - RETURN_ERROR_IF(ip==iend, srcSize_wrong, ""); + RETURN_ERROR_IF(op==oend, noForwardProgress_destFull, ""); + RETURN_ERROR_IF(ip==iend, noForwardProgress_inputEmpty, ""); assert(0); } } else { @@ -2140,11 +2287,17 @@ size_t ZSTD_decompressStream_simpleArgs ( void* dst, size_t dstCapacity, size_t* dstPos, const void* src, size_t srcSize, size_t* srcPos) { - ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; - ZSTD_inBuffer input = { src, srcSize, *srcPos }; - /* ZSTD_compress_generic() will check validity of dstPos and srcPos */ - size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); - *dstPos = output.pos; - *srcPos = input.pos; - return cErr; + ZSTD_outBuffer output; + ZSTD_inBuffer input; + output.dst = dst; + output.size = dstCapacity; + output.pos = *dstPos; + input.src = src; + input.size = srcSize; + input.pos = *srcPos; + { size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); + *dstPos = output.pos; + *srcPos = input.pos; + return cErr; + } } diff --git a/lib/zstd/decompress/zstd_decompress_block.c b/lib/zstd/decompress/zstd_decompress_block.c index c1913b8e7c89..710eb0ffd5a3 100644 --- a/lib/zstd/decompress/zstd_decompress_block.c +++ b/lib/zstd/decompress/zstd_decompress_block.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -20,12 +21,12 @@ #include "../common/mem.h" /* low level memory routines */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "../common/zstd_internal.h" #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ #include "zstd_decompress_block.h" +#include "../common/bits.h" /* ZSTD_highbit32 */ /*_******************************************************* * Macros @@ -51,6 +52,13 @@ static void ZSTD_copy4(void* dst, const void* src) { ZSTD_memcpy(dst, src, 4); } * Block decoding ***************************************************************/ +static size_t ZSTD_blockSizeMax(ZSTD_DCtx const* dctx) +{ + size_t const blockSizeMax = dctx->isFrameDecompression ? dctx->fParams.blockSizeMax : ZSTD_BLOCKSIZE_MAX; + assert(blockSizeMax <= ZSTD_BLOCKSIZE_MAX); + return blockSizeMax; +} + /*! ZSTD_getcBlockSize() : * Provides the size of compressed block from block header `src` */ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, @@ -73,41 +81,49 @@ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, static void ZSTD_allocateLiteralsBuffer(ZSTD_DCtx* dctx, void* const dst, const size_t dstCapacity, const size_t litSize, const streaming_operation streaming, const size_t expectedWriteSize, const unsigned splitImmediately) { - if (streaming == not_streaming && dstCapacity > ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) - { - /* room for litbuffer to fit without read faulting */ - dctx->litBuffer = (BYTE*)dst + ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH; + size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); + assert(litSize <= blockSizeMax); + assert(dctx->isFrameDecompression || streaming == not_streaming); + assert(expectedWriteSize <= blockSizeMax); + if (streaming == not_streaming && dstCapacity > blockSizeMax + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) { + /* If we aren't streaming, we can just put the literals after the output + * of the current block. We don't need to worry about overwriting the + * extDict of our window, because it doesn't exist. + * So if we have space after the end of the block, just put it there. + */ + dctx->litBuffer = (BYTE*)dst + blockSizeMax + WILDCOPY_OVERLENGTH; dctx->litBufferEnd = dctx->litBuffer + litSize; dctx->litBufferLocation = ZSTD_in_dst; - } - else if (litSize > ZSTD_LITBUFFEREXTRASIZE) - { - /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ + } else if (litSize <= ZSTD_LITBUFFEREXTRASIZE) { + /* Literals fit entirely within the extra buffer, put them there to avoid + * having to split the literals. + */ + dctx->litBuffer = dctx->litExtraBuffer; + dctx->litBufferEnd = dctx->litBuffer + litSize; + dctx->litBufferLocation = ZSTD_not_in_dst; + } else { + assert(blockSizeMax > ZSTD_LITBUFFEREXTRASIZE); + /* Literals must be split between the output block and the extra lit + * buffer. We fill the extra lit buffer with the tail of the literals, + * and put the rest of the literals at the end of the block, with + * WILDCOPY_OVERLENGTH of buffer room to allow for overreads. + * This MUST not write more than our maxBlockSize beyond dst, because in + * streaming mode, that could overwrite part of our extDict window. + */ if (splitImmediately) { /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; dctx->litBufferEnd = dctx->litBuffer + litSize - ZSTD_LITBUFFEREXTRASIZE; - } - else { - /* initially this will be stored entirely in dst during huffman decoding, it will partially shifted to litExtraBuffer after */ + } else { + /* initially this will be stored entirely in dst during huffman decoding, it will partially be shifted to litExtraBuffer after */ dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize; dctx->litBufferEnd = (BYTE*)dst + expectedWriteSize; } dctx->litBufferLocation = ZSTD_split; - } - else - { - /* fits entirely within litExtraBuffer, so no split is necessary */ - dctx->litBuffer = dctx->litExtraBuffer; - dctx->litBufferEnd = dctx->litBuffer + litSize; - dctx->litBufferLocation = ZSTD_not_in_dst; + assert(dctx->litBufferEnd <= (BYTE*)dst + expectedWriteSize); } } -/* Hidden declaration for fullbench */ -size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, - const void* src, size_t srcSize, - void* dst, size_t dstCapacity, const streaming_operation streaming); /*! ZSTD_decodeLiteralsBlock() : * Where it is possible to do so without being stomped by the output during decompression, the literals block will be stored * in the dstBuffer. If there is room to do so, it will be stored in full in the excess dst space after where the current @@ -116,7 +132,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, * * @return : nb of bytes read from src (< srcSize ) * note : symbol not declared but exposed for fullbench */ -size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, +static size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, const void* src, size_t srcSize, /* note : srcSize < BLOCKSIZE */ void* dst, size_t dstCapacity, const streaming_operation streaming) { @@ -124,7 +140,8 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, RETURN_ERROR_IF(srcSize < MIN_CBLOCK_SIZE, corruption_detected, ""); { const BYTE* const istart = (const BYTE*) src; - symbolEncodingType_e const litEncType = (symbolEncodingType_e)(istart[0] & 3); + SymbolEncodingType_e const litEncType = (SymbolEncodingType_e)(istart[0] & 3); + size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); switch(litEncType) { @@ -134,13 +151,16 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, ZSTD_FALLTHROUGH; case set_compressed: - RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3"); + RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need up to 5 for case 3"); { size_t lhSize, litSize, litCSize; U32 singleStream=0; U32 const lhlCode = (istart[0] >> 2) & 3; U32 const lhc = MEM_readLE32(istart); size_t hufSuccess; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); + int const flags = 0 + | (ZSTD_DCtx_get_bmi2(dctx) ? HUF_flags_bmi2 : 0) + | (dctx->disableHufAsm ? HUF_flags_disableAsm : 0); switch(lhlCode) { case 0: case 1: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -164,7 +184,11 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); - RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); + if (!singleStream) + RETURN_ERROR_IF(litSize < MIN_LITERALS_FOR_4_STREAMS, literals_headerWrong, + "Not enough literals (%zu) for the 4-streams mode (min %u)", + litSize, MIN_LITERALS_FOR_4_STREAMS); RETURN_ERROR_IF(litCSize + lhSize > srcSize, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize , dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 0); @@ -176,13 +200,14 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, if (litEncType==set_repeat) { if (singleStream) { - hufSuccess = HUF_decompress1X_usingDTable_bmi2( + hufSuccess = HUF_decompress1X_usingDTable( dctx->litBuffer, litSize, istart+lhSize, litCSize, - dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); + dctx->HUFptr, flags); } else { - hufSuccess = HUF_decompress4X_usingDTable_bmi2( + assert(litSize >= MIN_LITERALS_FOR_4_STREAMS); + hufSuccess = HUF_decompress4X_usingDTable( dctx->litBuffer, litSize, istart+lhSize, litCSize, - dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); + dctx->HUFptr, flags); } } else { if (singleStream) { @@ -190,26 +215,28 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, hufSuccess = HUF_decompress1X_DCtx_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace)); + sizeof(dctx->workspace), flags); #else - hufSuccess = HUF_decompress1X1_DCtx_wksp_bmi2( + hufSuccess = HUF_decompress1X1_DCtx_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); + sizeof(dctx->workspace), flags); #endif } else { - hufSuccess = HUF_decompress4X_hufOnly_wksp_bmi2( + hufSuccess = HUF_decompress4X_hufOnly_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); + sizeof(dctx->workspace), flags); } } if (dctx->litBufferLocation == ZSTD_split) { + assert(litSize > ZSTD_LITBUFFEREXTRASIZE); ZSTD_memcpy(dctx->litExtraBuffer, dctx->litBufferEnd - ZSTD_LITBUFFEREXTRASIZE, ZSTD_LITBUFFEREXTRASIZE); ZSTD_memmove(dctx->litBuffer + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH, dctx->litBuffer, litSize - ZSTD_LITBUFFEREXTRASIZE); dctx->litBuffer += ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; dctx->litBufferEnd -= WILDCOPY_OVERLENGTH; + assert(dctx->litBufferEnd <= (BYTE*)dst + blockSizeMax); } RETURN_ERROR_IF(HUF_isError(hufSuccess), corruption_detected, ""); @@ -224,7 +251,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, case set_basic: { size_t litSize, lhSize; U32 const lhlCode = ((istart[0]) >> 2) & 3; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); switch(lhlCode) { case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -237,11 +264,13 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; case 3: lhSize = 3; + RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize = 3"); litSize = MEM_readLE24(istart) >> 4; break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 1); if (lhSize+litSize+WILDCOPY_OVERLENGTH > srcSize) { /* risk reading beyond src buffer with wildcopy */ @@ -270,7 +299,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, case set_rle: { U32 const lhlCode = ((istart[0]) >> 2) & 3; size_t litSize, lhSize; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); switch(lhlCode) { case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -279,16 +308,17 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; case 1: lhSize = 2; + RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 3"); litSize = MEM_readLE16(istart) >> 4; break; case 3: lhSize = 3; + RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 4"); litSize = MEM_readLE24(istart) >> 4; - RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4"); break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); - RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 1); if (dctx->litBufferLocation == ZSTD_split) @@ -310,6 +340,18 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, } } +/* Hidden declaration for fullbench */ +size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, + const void* src, size_t srcSize, + void* dst, size_t dstCapacity); +size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, + const void* src, size_t srcSize, + void* dst, size_t dstCapacity) +{ + dctx->isFrameDecompression = 0; + return ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, not_streaming); +} + /* Default FSE distribution tables. * These are pre-calculated FSE decoding tables using default distributions as defined in specification : * https://github.com/facebook/zstd/blob/release/doc/zstd_compression_format.md#default-distributions @@ -317,7 +359,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, * - start from default distributions, present in /lib/common/zstd_internal.h * - generate tables normally, using ZSTD_buildFSETable() * - printout the content of tables - * - pretify output, report below, test with fuzzer to ensure it's correct */ + * - prettify output, report below, test with fuzzer to ensure it's correct */ /* Default FSE distribution table for Literal Lengths */ static const ZSTD_seqSymbol LL_defaultDTable[(1<=0); + pos += (size_t)n; } } /* Now we spread those positions across the table. - * The benefit of doing it in two stages is that we avoid the the + * The benefit of doing it in two stages is that we avoid the * variable size inner loop, which caused lots of branch misses. * Now we can run through all the positions without any branch misses. - * We unroll the loop twice, since that is what emperically worked best. + * We unroll the loop twice, since that is what empirically worked best. */ { size_t position = 0; @@ -540,7 +583,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, for (i=0; i highThreshold) position = (position + step) & tableMask; /* lowprob area */ + while (UNLIKELY(position > highThreshold)) position = (position + step) & tableMask; /* lowprob area */ } } assert(position == 0); /* position must reach all cells once, otherwise normalizedCounter is incorrect */ } @@ -551,7 +594,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, for (u=0; u 0x7F) { if (nbSeq == 0xFF) { RETURN_ERROR_IF(ip+2 > iend, srcSize_wrong, ""); @@ -681,11 +719,19 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, } *nbSeqPtr = nbSeq; + if (nbSeq == 0) { + /* No sequence : section ends immediately */ + RETURN_ERROR_IF(ip != iend, corruption_detected, + "extraneous data present in the Sequences section"); + return (size_t)(ip - istart); + } + /* FSE table descriptors */ RETURN_ERROR_IF(ip+1 > iend, srcSize_wrong, ""); /* minimum possible size: 1 byte for symbol encoding types */ - { symbolEncodingType_e const LLtype = (symbolEncodingType_e)(*ip >> 6); - symbolEncodingType_e const OFtype = (symbolEncodingType_e)((*ip >> 4) & 3); - symbolEncodingType_e const MLtype = (symbolEncodingType_e)((*ip >> 2) & 3); + RETURN_ERROR_IF(*ip & 3, corruption_detected, ""); /* The last field, Reserved, must be all-zeroes. */ + { SymbolEncodingType_e const LLtype = (SymbolEncodingType_e)(*ip >> 6); + SymbolEncodingType_e const OFtype = (SymbolEncodingType_e)((*ip >> 4) & 3); + SymbolEncodingType_e const MLtype = (SymbolEncodingType_e)((*ip >> 2) & 3); ip++; /* Build DTables */ @@ -829,7 +875,7 @@ static void ZSTD_safecopy(BYTE* op, const BYTE* const oend_w, BYTE const* ip, pt /* ZSTD_safecopyDstBeforeSrc(): * This version allows overlap with dst before src, or handles the non-overlap case with dst after src * Kept separate from more common ZSTD_safecopy case to avoid performance impact to the safecopy common case */ -static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length) { +static void ZSTD_safecopyDstBeforeSrc(BYTE* op, const BYTE* ip, ptrdiff_t length) { ptrdiff_t const diff = op - ip; BYTE* const oend = op + length; @@ -858,6 +904,7 @@ static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length * to be optimized for many small sequences, since those fall into ZSTD_execSequence(). */ FORCE_NOINLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceEnd(BYTE* op, BYTE* const oend, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -905,6 +952,7 @@ size_t ZSTD_execSequenceEnd(BYTE* op, * This version is intended to be used during instances where the litBuffer is still split. It is kept separate to avoid performance impact for the good case. */ FORCE_NOINLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, BYTE* const oend, const BYTE* const oend_w, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -950,6 +998,7 @@ size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, } HINT_INLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequence(BYTE* op, BYTE* const oend, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -964,6 +1013,11 @@ size_t ZSTD_execSequence(BYTE* op, assert(op != NULL /* Precondition */); assert(oend_w < oend /* No underflow */); + +#if defined(__aarch64__) + /* prefetch sequence starting from match that will be used for copy later */ + PREFETCH_L1(match); +#endif /* Handle edge cases in a slow path: * - Read beyond end of literals * - Match end is within WILDCOPY_OVERLIMIT of oend @@ -1043,6 +1097,7 @@ size_t ZSTD_execSequence(BYTE* op, } HINT_INLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceSplitLitBuffer(BYTE* op, BYTE* const oend, const BYTE* const oend_w, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -1154,7 +1209,7 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 } /* We need to add at most (ZSTD_WINDOWLOG_MAX_32 - 1) bits to read the maximum - * offset bits. But we can only read at most (STREAM_ACCUMULATOR_MIN_32 - 1) + * offset bits. But we can only read at most STREAM_ACCUMULATOR_MIN_32 * bits before reloading. This value is the maximum number of bytes we read * after reloading when we are decoding long offsets. */ @@ -1165,13 +1220,37 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e; +/* + * ZSTD_decodeSequence(): + * @p longOffsets : tells the decoder to reload more bit while decoding large offsets + * only used in 32-bit mode + * @return : Sequence (litL + matchL + offset) + */ FORCE_INLINE_TEMPLATE seq_t -ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) +ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const int isLastSeq) { seq_t seq; + /* + * ZSTD_seqSymbol is a 64 bits wide structure. + * It can be loaded in one operation + * and its fields extracted by simply shifting or bit-extracting on aarch64. + * GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh + * operations that cause performance drop. This can be avoided by using this + * ZSTD_memcpy hack. + */ +#if defined(__aarch64__) && (defined(__GNUC__) && !defined(__clang__)) + ZSTD_seqSymbol llDInfoS, mlDInfoS, ofDInfoS; + ZSTD_seqSymbol* const llDInfo = &llDInfoS; + ZSTD_seqSymbol* const mlDInfo = &mlDInfoS; + ZSTD_seqSymbol* const ofDInfo = &ofDInfoS; + ZSTD_memcpy(llDInfo, seqState->stateLL.table + seqState->stateLL.state, sizeof(ZSTD_seqSymbol)); + ZSTD_memcpy(mlDInfo, seqState->stateML.table + seqState->stateML.state, sizeof(ZSTD_seqSymbol)); + ZSTD_memcpy(ofDInfo, seqState->stateOffb.table + seqState->stateOffb.state, sizeof(ZSTD_seqSymbol)); +#else const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state; const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state; const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state; +#endif seq.matchLength = mlDInfo->baseValue; seq.litLength = llDInfo->baseValue; { U32 const ofBase = ofDInfo->baseValue; @@ -1186,28 +1265,31 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) U32 const llnbBits = llDInfo->nbBits; U32 const mlnbBits = mlDInfo->nbBits; U32 const ofnbBits = ofDInfo->nbBits; + + assert(llBits <= MaxLLBits); + assert(mlBits <= MaxMLBits); + assert(ofBits <= MaxOff); /* * As gcc has better branch and block analyzers, sometimes it is only - * valuable to mark likelyness for clang, it gives around 3-4% of + * valuable to mark likeliness for clang, it gives around 3-4% of * performance. */ /* sequence */ { size_t offset; - #if defined(__clang__) - if (LIKELY(ofBits > 1)) { - #else if (ofBits > 1) { - #endif ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1); ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5); - assert(ofBits <= MaxOff); + ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32); + ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits); if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) { - U32 const extraBits = ofBits - MIN(ofBits, 32 - seqState->DStream.bitsConsumed); + /* Always read extra bits, this keeps the logic simple, + * avoids branches, and avoids accidentally reading 0 bits. + */ + U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32; offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits); BIT_reloadDStream(&seqState->DStream); - if (extraBits) offset += BIT_readBitsFast(&seqState->DStream, extraBits); - assert(extraBits <= LONG_OFFSETS_MAX_EXTRA_BITS_32); /* to avoid another reload */ + offset += BIT_readBitsFast(&seqState->DStream, extraBits); } else { offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */ if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); @@ -1224,7 +1306,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) } else { offset = ofBase + ll0 + BIT_readBitsFast(&seqState->DStream, 1); { size_t temp = (offset==3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset]; - temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */ + temp -= !temp; /* 0 is not valid: input corrupted => force offset to -1 => corruption detected at execSequence */ if (offset != 1) seqState->prevOffset[2] = seqState->prevOffset[1]; seqState->prevOffset[1] = seqState->prevOffset[0]; seqState->prevOffset[0] = offset = temp; @@ -1232,11 +1314,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) seq.offset = offset; } - #if defined(__clang__) - if (UNLIKELY(mlBits > 0)) - #else if (mlBits > 0) - #endif seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/); if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32)) @@ -1246,11 +1324,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) /* Ensure there are enough bits to read the rest of data in 64-bit mode. */ ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64); - #if defined(__clang__) - if (UNLIKELY(llBits > 0)) - #else if (llBits > 0) - #endif seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/); if (MEM_32bits()) @@ -1259,17 +1333,22 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u", (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); - ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ - ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ - if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ - ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ + if (!isLastSeq) { + /* don't update FSE state for last Sequence */ + ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ + ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ + if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ + ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ + BIT_reloadDStream(&seqState->DStream); + } } return seq; } -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION -MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) +#if DEBUGLEVEL >= 1 +static int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) { size_t const windowSize = dctx->fParams.windowSize; /* No dictionary used. */ @@ -1283,30 +1362,33 @@ MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefix /* Dictionary is active. */ return 1; } +#endif -MEM_STATIC void ZSTD_assertValidSequence( +static void ZSTD_assertValidSequence( ZSTD_DCtx const* dctx, BYTE const* op, BYTE const* oend, seq_t const seq, BYTE const* prefixStart, BYTE const* virtualStart) { #if DEBUGLEVEL >= 1 - size_t const windowSize = dctx->fParams.windowSize; - size_t const sequenceSize = seq.litLength + seq.matchLength; - BYTE const* const oLitEnd = op + seq.litLength; - DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", - (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); - assert(op <= oend); - assert((size_t)(oend - op) >= sequenceSize); - assert(sequenceSize <= ZSTD_BLOCKSIZE_MAX); - if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { - size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); - /* Offset must be within the dictionary. */ - assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); - assert(seq.offset <= windowSize + dictSize); - } else { - /* Offset must be within our window. */ - assert(seq.offset <= windowSize); + if (dctx->isFrameDecompression) { + size_t const windowSize = dctx->fParams.windowSize; + size_t const sequenceSize = seq.litLength + seq.matchLength; + BYTE const* const oLitEnd = op + seq.litLength; + DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", + (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); + assert(op <= oend); + assert((size_t)(oend - op) >= sequenceSize); + assert(sequenceSize <= ZSTD_blockSizeMax(dctx)); + if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { + size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); + /* Offset must be within the dictionary. */ + assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); + assert(seq.offset <= windowSize + dictSize); + } else { + /* Offset must be within our window. */ + assert(seq.offset <= windowSize); + } } #else (void)dctx, (void)op, (void)oend, (void)seq, (void)prefixStart, (void)virtualStart; @@ -1322,23 +1404,21 @@ DONT_VECTORIZE ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = ostart + maxDstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, maxDstSize); BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* litBufferEnd = dctx->litBufferEnd; const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); const BYTE* const vBase = (const BYTE*) (dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); - DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer"); - (void)frame; + DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer (%i seqs)", nbSeq); - /* Regen sequences */ + /* Literals are split between internal buffer & output buffer */ if (nbSeq) { seqState_t seqState; dctx->fseEntropy = 1; @@ -1357,8 +1437,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, BIT_DStream_completed < BIT_DStream_overflow); /* decompress without overrunning litPtr begins */ - { - seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + { seq_t sequence = {0,0,0}; /* some static analyzer believe that @sequence is not initialized (it necessarily is, since for(;;) loop as at least one iteration) */ /* Align the decompression loop to 32 + 16 bytes. * * zstd compiled with gcc-9 on an Intel i9-9900k shows 10% decompression @@ -1420,27 +1499,26 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, #endif /* Handle the initial state where litBuffer is currently split between dst and litExtraBuffer */ - for (; litPtr + sequence.litLength <= dctx->litBufferEnd; ) { - size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + for ( ; nbSeq; nbSeq--) { + sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); + if (litPtr + sequence.litLength > dctx->litBufferEnd) break; + { size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) - assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + assert(!ZSTD_isError(oneSeqSize)); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif - if (UNLIKELY(ZSTD_isError(oneSeqSize))) - return oneSeqSize; - DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); - op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); - sequence = ZSTD_decodeSequence(&seqState, isLongOffset); - } + if (UNLIKELY(ZSTD_isError(oneSeqSize))) + return oneSeqSize; + DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); + op += oneSeqSize; + } } + DEBUGLOG(6, "reached: (litPtr + sequence.litLength > dctx->litBufferEnd)"); /* If there are more sequences, they will need to read literals from litExtraBuffer; copy over the remainder from dst and update litPtr and litEnd */ if (nbSeq > 0) { const size_t leftoverLit = dctx->litBufferEnd - litPtr; - if (leftoverLit) - { + DEBUGLOG(6, "There are %i sequences left, and %zu/%zu literals left in buffer", nbSeq, leftoverLit, sequence.litLength); + if (leftoverLit) { RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); sequence.litLength -= leftoverLit; @@ -1449,24 +1527,22 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - { - size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (--nbSeq) - BIT_reloadDStream(&(seqState.DStream)); } + nbSeq--; } } - if (nbSeq > 0) /* there is remaining lit from extra buffer */ - { + if (nbSeq > 0) { + /* there is remaining lit from extra buffer */ #if defined(__x86_64__) __asm__(".p2align 6"); @@ -1485,35 +1561,34 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, # endif #endif - for (; ; ) { - seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + for ( ; nbSeq ; nbSeq--) { + seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); } } /* check if reached exact end */ DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer: after decode loop, remaining nbSeq : %i", nbSeq); RETURN_ERROR_IF(nbSeq, corruption_detected, ""); - RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); + DEBUGLOG(5, "bitStream : start=%p, ptr=%p, bitsConsumed=%u", seqState.DStream.start, seqState.DStream.ptr, seqState.DStream.bitsConsumed); + RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); /* save reps for next block */ { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } } /* last literal segment */ - if (dctx->litBufferLocation == ZSTD_split) /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ - { - size_t const lastLLSize = litBufferEnd - litPtr; + if (dctx->litBufferLocation == ZSTD_split) { + /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ + size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); + DEBUGLOG(6, "copy last literals from segment : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memmove(op, litPtr, lastLLSize); @@ -1523,15 +1598,17 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; } - { size_t const lastLLSize = litBufferEnd - litPtr; + /* copy last literals from internal buffer */ + { size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); + DEBUGLOG(6, "copy last literals from internal buffer : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memcpy(op, litPtr, lastLLSize); op += lastLLSize; - } - } + } } - return op-ostart; + DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); + return (size_t)(op - ostart); } FORCE_INLINE_TEMPLATE size_t @@ -1539,21 +1616,19 @@ DONT_VECTORIZE ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? ostart + maxDstSize : dctx->litBuffer; + BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? ZSTD_maybeNullPtrAdd(ostart, maxDstSize) : dctx->litBuffer; BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* const litEnd = litPtr + dctx->litSize; const BYTE* const prefixStart = (const BYTE*)(dctx->prefixStart); const BYTE* const vBase = (const BYTE*)(dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*)(dctx->dictEnd); - DEBUGLOG(5, "ZSTD_decompressSequences_body"); - (void)frame; + DEBUGLOG(5, "ZSTD_decompressSequences_body: nbSeq = %d", nbSeq); /* Regen sequences */ if (nbSeq) { @@ -1568,11 +1643,6 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); assert(dst != NULL); - ZSTD_STATIC_ASSERT( - BIT_DStream_unfinished < BIT_DStream_completed && - BIT_DStream_endOfBuffer < BIT_DStream_completed && - BIT_DStream_completed < BIT_DStream_overflow); - #if defined(__x86_64__) __asm__(".p2align 6"); __asm__("nop"); @@ -1587,73 +1657,70 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, # endif #endif - for ( ; ; ) { - seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + for ( ; nbSeq ; nbSeq--) { + seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); } /* check if reached exact end */ - DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode loop, remaining nbSeq : %i", nbSeq); - RETURN_ERROR_IF(nbSeq, corruption_detected, ""); - RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); + assert(nbSeq == 0); + RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); /* save reps for next block */ { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } } /* last literal segment */ - { size_t const lastLLSize = litEnd - litPtr; + { size_t const lastLLSize = (size_t)(litEnd - litPtr); + DEBUGLOG(6, "copy last literals : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memcpy(op, litPtr, lastLLSize); op += lastLLSize; - } - } + } } - return op-ostart; + DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); + return (size_t)(op - ostart); } static size_t ZSTD_decompressSequences_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static size_t ZSTD_decompressSequencesSplitLitBuffer_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT -FORCE_INLINE_TEMPLATE size_t -ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, +FORCE_INLINE_TEMPLATE + +size_t ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, const BYTE* const prefixStart, const BYTE* const dictEnd) { prefetchPos += sequence.litLength; { const BYTE* const matchBase = (sequence.offset > prefetchPos) ? dictEnd : prefixStart; - const BYTE* const match = matchBase + prefetchPos - sequence.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. - * No consequence though : memory address is only used for prefetching, not for dereferencing */ + /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. + * No consequence though : memory address is only used for prefetching, not for dereferencing */ + const BYTE* const match = ZSTD_wrappedPtrSub(ZSTD_wrappedPtrAdd(matchBase, prefetchPos), sequence.offset); PREFETCH_L1(match); PREFETCH_L1(match+CACHELINE_SIZE); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */ } return prefetchPos + sequence.matchLength; @@ -1668,20 +1735,18 @@ ZSTD_decompressSequencesLong_body( ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? dctx->litBuffer : ostart + maxDstSize; + BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? dctx->litBuffer : ZSTD_maybeNullPtrAdd(ostart, maxDstSize); BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* litBufferEnd = dctx->litBufferEnd; const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); const BYTE* const dictStart = (const BYTE*) (dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); - (void)frame; /* Regen sequences */ if (nbSeq) { @@ -1706,20 +1771,17 @@ ZSTD_decompressSequencesLong_body( ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); /* prepare in advance */ - for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNblitBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) - { + if (dctx->litBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) { /* lit buffer is reaching split point, empty out the first buffer and transition to litExtraBuffer */ const size_t leftoverLit = dctx->litBufferEnd - litPtr; if (leftoverLit) @@ -1732,26 +1794,26 @@ ZSTD_decompressSequencesLong_body( litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) - assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + assert(!ZSTD_isError(oneSeqSize)); + ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); #endif - if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + if (ZSTD_isError(oneSeqSize)) return oneSeqSize; - prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); - sequences[seqNb & STORED_SEQS_MASK] = sequence; - op += oneSeqSize; - } + prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); + sequences[seqNb & STORED_SEQS_MASK] = sequence; + op += oneSeqSize; + } } else { /* lit buffer is either wholly contained in first or second split, or not split at all*/ - oneSeqSize = dctx->litBufferLocation == ZSTD_split ? + size_t const oneSeqSize = dctx->litBufferLocation == ZSTD_split ? ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength - WILDCOPY_OVERLENGTH, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd) : ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; @@ -1760,17 +1822,15 @@ ZSTD_decompressSequencesLong_body( op += oneSeqSize; } } - RETURN_ERROR_IF(seqNblitBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) - { + if (dctx->litBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) { const size_t leftoverLit = dctx->litBufferEnd - litPtr; - if (leftoverLit) - { + if (leftoverLit) { RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); sequence->litLength -= leftoverLit; @@ -1779,11 +1839,10 @@ ZSTD_decompressSequencesLong_body( litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - { - size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; op += oneSeqSize; @@ -1796,7 +1855,7 @@ ZSTD_decompressSequencesLong_body( ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; op += oneSeqSize; @@ -1808,8 +1867,7 @@ ZSTD_decompressSequencesLong_body( } /* last literal segment */ - if (dctx->litBufferLocation == ZSTD_split) /* first deplete literal buffer in dst, then copy litExtraBuffer */ - { + if (dctx->litBufferLocation == ZSTD_split) { /* first deplete literal buffer in dst, then copy litExtraBuffer */ size_t const lastLLSize = litBufferEnd - litPtr; RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); if (op != NULL) { @@ -1827,17 +1885,16 @@ ZSTD_decompressSequencesLong_body( } } - return op-ostart; + return (size_t)(op - ostart); } static size_t ZSTD_decompressSequencesLong_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ @@ -1851,20 +1908,18 @@ DONT_VECTORIZE ZSTD_decompressSequences_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static BMI2_TARGET_ATTRIBUTE size_t DONT_VECTORIZE ZSTD_decompressSequencesSplitLitBuffer_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ @@ -1873,50 +1928,40 @@ static BMI2_TARGET_ATTRIBUTE size_t ZSTD_decompressSequencesLong_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ #endif /* DYNAMIC_BMI2 */ -typedef size_t (*ZSTD_decompressSequences_t)( - ZSTD_DCtx* dctx, - void* dst, size_t maxDstSize, - const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame); - #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG static size_t ZSTD_decompressSequences(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequences"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static size_t ZSTD_decompressSequencesSplitLitBuffer(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequencesSplitLitBuffer"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ @@ -1931,69 +1976,114 @@ static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequencesLong"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ +/* + * @returns The total size of the history referenceable by zstd, including + * both the prefix and the extDict. At @p op any offset larger than this + * is invalid. + */ +static size_t ZSTD_totalHistorySize(BYTE* op, BYTE const* virtualStart) +{ + return (size_t)(op - virtualStart); +} + +typedef struct { + unsigned longOffsetShare; + unsigned maxNbAdditionalBits; +} ZSTD_OffsetInfo; -#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ - !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) -/* ZSTD_getLongOffsetsShare() : +/* ZSTD_getOffsetInfo() : * condition : offTable must be valid * @return : "share" of long offsets (arbitrarily defined as > (1<<23)) - * compared to maximum possible of (1< 22) total += 1; + ZSTD_OffsetInfo info = {0, 0}; + /* If nbSeq == 0, then the offTable is uninitialized, but we have + * no sequences, so both values should be 0. + */ + if (nbSeq != 0) { + const void* ptr = offTable; + U32 const tableLog = ((const ZSTD_seqSymbol_header*)ptr)[0].tableLog; + const ZSTD_seqSymbol* table = offTable + 1; + U32 const max = 1 << tableLog; + U32 u; + DEBUGLOG(5, "ZSTD_getLongOffsetsShare: (tableLog=%u)", tableLog); + + assert(max <= (1 << OffFSELog)); /* max not too large */ + for (u=0; u 22) info.longOffsetShare += 1; + } + + assert(tableLog <= OffFSELog); + info.longOffsetShare <<= (OffFSELog - tableLog); /* scale to OffFSELog */ } - assert(tableLog <= OffFSELog); - total <<= (OffFSELog - tableLog); /* scale to OffFSELog */ + return info; +} - return total; +/* + * @returns The maximum offset we can decode in one read of our bitstream, without + * reloading more bits in the middle of the offset bits read. Any offsets larger + * than this must use the long offset decoder. + */ +static size_t ZSTD_maxShortOffset(void) +{ + if (MEM_64bits()) { + /* We can decode any offset without reloading bits. + * This might change if the max window size grows. + */ + ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31); + return (size_t)-1; + } else { + /* The maximum offBase is (1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1. + * This offBase would require STREAM_ACCUMULATOR_MIN extra bits. + * Then we have to subtract ZSTD_REP_NUM to get the maximum possible offset. + */ + size_t const maxOffbase = ((size_t)1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1; + size_t const maxOffset = maxOffbase - ZSTD_REP_NUM; + assert(ZSTD_highbit32((U32)maxOffbase) == STREAM_ACCUMULATOR_MIN); + return maxOffset; + } } -#endif size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, const int frame, const streaming_operation streaming) + const void* src, size_t srcSize, const streaming_operation streaming) { /* blockType == blockCompressed */ const BYTE* ip = (const BYTE*)src; - /* isLongOffset must be true if there are long offsets. - * Offsets are long if they are larger than 2^STREAM_ACCUMULATOR_MIN. - * We don't expect that to be the case in 64-bit mode. - * In block mode, window size is not known, so we have to be conservative. - * (note: but it could be evaluated from current-lowLimit) - */ - ZSTD_longOffset_e const isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (!frame || (dctx->fParams.windowSize > (1ULL << STREAM_ACCUMULATOR_MIN)))); - DEBUGLOG(5, "ZSTD_decompressBlock_internal (size : %u)", (U32)srcSize); - - RETURN_ERROR_IF(srcSize >= ZSTD_BLOCKSIZE_MAX, srcSize_wrong, ""); + DEBUGLOG(5, "ZSTD_decompressBlock_internal (cSize : %u)", (unsigned)srcSize); + + /* Note : the wording of the specification + * allows compressed block to be sized exactly ZSTD_blockSizeMax(dctx). + * This generally does not happen, as it makes little sense, + * since an uncompressed block would feature same size and have no decompression cost. + * Also, note that decoder from reference libzstd before < v1.5.4 + * would consider this edge case as an error. + * As a consequence, avoid generating compressed blocks of size ZSTD_blockSizeMax(dctx) + * for broader compatibility with the deployed ecosystem of zstd decoders */ + RETURN_ERROR_IF(srcSize > ZSTD_blockSizeMax(dctx), srcSize_wrong, ""); /* Decode literals section */ { size_t const litCSize = ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, streaming); - DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : %u", (U32)litCSize); + DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : cSize=%u, nbLiterals=%zu", (U32)litCSize, dctx->litSize); if (ZSTD_isError(litCSize)) return litCSize; ip += litCSize; srcSize -= litCSize; @@ -2001,6 +2091,23 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, /* Build Decoding Tables */ { + /* Compute the maximum block size, which must also work when !frame and fParams are unset. + * Additionally, take the min with dstCapacity to ensure that the totalHistorySize fits in a size_t. + */ + size_t const blockSizeMax = MIN(dstCapacity, ZSTD_blockSizeMax(dctx)); + size_t const totalHistorySize = ZSTD_totalHistorySize(ZSTD_maybeNullPtrAdd((BYTE*)dst, blockSizeMax), (BYTE const*)dctx->virtualStart); + /* isLongOffset must be true if there are long offsets. + * Offsets are long if they are larger than ZSTD_maxShortOffset(). + * We don't expect that to be the case in 64-bit mode. + * + * We check here to see if our history is large enough to allow long offsets. + * If it isn't, then we can't possible have (valid) long offsets. If the offset + * is invalid, then it is okay to read it incorrectly. + * + * If isLongOffsets is true, then we will later check our decoding table to see + * if it is even possible to generate long offsets. + */ + ZSTD_longOffset_e isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (totalHistorySize > ZSTD_maxShortOffset())); /* These macros control at build-time which decompressor implementation * we use. If neither is defined, we do some inspection and dispatch at * runtime. @@ -2008,6 +2115,11 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) int usePrefetchDecoder = dctx->ddictIsCold; +#else + /* Set to 1 to avoid computing offset info if we don't need to. + * Otherwise this value is ignored. + */ + int usePrefetchDecoder = 1; #endif int nbSeq; size_t const seqHSize = ZSTD_decodeSeqHeaders(dctx, &nbSeq, ip, srcSize); @@ -2015,40 +2127,55 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, ip += seqHSize; srcSize -= seqHSize; - RETURN_ERROR_IF(dst == NULL && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF((dst == NULL || dstCapacity == 0) && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF(MEM_64bits() && sizeof(size_t) == sizeof(void*) && (size_t)(-1) - (size_t)dst < (size_t)(1 << 20), dstSize_tooSmall, + "invalid dst"); -#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ - !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) - if ( !usePrefetchDecoder - && (!frame || (dctx->fParams.windowSize > (1<<24))) - && (nbSeq>ADVANCED_SEQS) ) { /* could probably use a larger nbSeq limit */ - U32 const shareLongOffsets = ZSTD_getLongOffsetsShare(dctx->OFTptr); - U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ - usePrefetchDecoder = (shareLongOffsets >= minShare); + /* If we could potentially have long offsets, or we might want to use the prefetch decoder, + * compute information about the share of long offsets, and the maximum nbAdditionalBits. + * NOTE: could probably use a larger nbSeq limit + */ + if (isLongOffset || (!usePrefetchDecoder && (totalHistorySize > (1u << 24)) && (nbSeq > 8))) { + ZSTD_OffsetInfo const info = ZSTD_getOffsetInfo(dctx->OFTptr, nbSeq); + if (isLongOffset && info.maxNbAdditionalBits <= STREAM_ACCUMULATOR_MIN) { + /* If isLongOffset, but the maximum number of additional bits that we see in our table is small + * enough, then we know it is impossible to have too long an offset in this block, so we can + * use the regular offset decoder. + */ + isLongOffset = ZSTD_lo_isRegularOffset; + } + if (!usePrefetchDecoder) { + U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ + usePrefetchDecoder = (info.longOffsetShare >= minShare); + } } -#endif dctx->ddictIsCold = 0; #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) - if (usePrefetchDecoder) + if (usePrefetchDecoder) { +#else + (void)usePrefetchDecoder; + { #endif #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT - return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); #endif + } #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG /* else */ if (dctx->litBufferLocation == ZSTD_split) - return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); else - return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); #endif } } +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) { if (dst != dctx->previousDstEnd && dstSize > 0) { /* not contiguous */ @@ -2060,13 +2187,24 @@ void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) } -size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { size_t dSize; + dctx->isFrameDecompression = 0; ZSTD_checkContinuity(dctx, dst, dstCapacity); - dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 0, not_streaming); + dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, not_streaming); + FORWARD_IF_ERROR(dSize, ""); dctx->previousDstEnd = (char*)dst + dSize; return dSize; } + + +/* NOTE: Must just wrap ZSTD_decompressBlock_deprecated() */ +size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_decompressBlock_deprecated(dctx, dst, dstCapacity, src, srcSize); +} diff --git a/lib/zstd/decompress/zstd_decompress_block.h b/lib/zstd/decompress/zstd_decompress_block.h index 3d2d57a5d25a..becffbd89364 100644 --- a/lib/zstd/decompress/zstd_decompress_block.h +++ b/lib/zstd/decompress/zstd_decompress_block.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -47,7 +48,7 @@ typedef enum { */ size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, const int frame, const streaming_operation streaming); + const void* src, size_t srcSize, const streaming_operation streaming); /* ZSTD_buildFSETable() : * generate FSE decoding table for one symbol (ll, ml or off) @@ -64,5 +65,10 @@ void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, unsigned tableLog, void* wksp, size_t wkspSize, int bmi2); +/* Internal definition of ZSTD_decompressBlock() to avoid deprecation warnings. */ +size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + #endif /* ZSTD_DEC_BLOCK_H */ diff --git a/lib/zstd/decompress/zstd_decompress_internal.h b/lib/zstd/decompress/zstd_decompress_internal.h index 98102edb6a83..2a225d1811c4 100644 --- a/lib/zstd/decompress/zstd_decompress_internal.h +++ b/lib/zstd/decompress/zstd_decompress_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -75,12 +76,13 @@ static UNUSED_ATTR const U32 ML_base[MaxML+1] = { #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE (sizeof(S16) * (MaxSeq + 1) + (1u << MaxFSELog) + sizeof(U64)) #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32 ((ZSTD_BUILD_FSE_TABLE_WKSP_SIZE + sizeof(U32) - 1) / sizeof(U32)) +#define ZSTD_HUFFDTABLE_CAPACITY_LOG 12 typedef struct { ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */ ZSTD_seqSymbol OFTable[SEQSYMBOL_TABLE_SIZE(OffFSELog)]; /* is also used as temporary workspace while building hufTable during DDict creation */ ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */ - HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ + HUF_DTable hufTable[HUF_DTABLE_SIZE(ZSTD_HUFFDTABLE_CAPACITY_LOG)]; /* can accommodate HUF_decompress4X */ U32 rep[ZSTD_REP_NUM]; U32 workspace[ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32]; } ZSTD_entropyDTables_t; @@ -135,7 +137,7 @@ struct ZSTD_DCtx_s const void* virtualStart; /* virtual start of previous segment if it was just before current one */ const void* dictEnd; /* end of previous segment */ size_t expected; - ZSTD_frameHeader fParams; + ZSTD_FrameHeader fParams; U64 processedCSize; U64 decodedSize; blockType_e bType; /* used in ZSTD_decompressContinue(), store blockType between block header decoding and block decompression stages */ @@ -152,7 +154,8 @@ struct ZSTD_DCtx_s size_t litSize; size_t rleSize; size_t staticSize; -#if DYNAMIC_BMI2 != 0 + int isFrameDecompression; +#if DYNAMIC_BMI2 int bmi2; /* == 1 if the CPU supports BMI2 and 0 otherwise. CPU support is determined dynamically once per context lifetime. */ #endif @@ -164,6 +167,8 @@ struct ZSTD_DCtx_s ZSTD_dictUses_e dictUses; ZSTD_DDictHashSet* ddictSet; /* Hash set for multiple ddicts */ ZSTD_refMultipleDDicts_e refMultipleDDicts; /* User specified: if == 1, will allow references to multiple DDicts. Default == 0 (disabled) */ + int disableHufAsm; + int maxBlockSizeParam; /* streaming */ ZSTD_dStreamStage streamStage; @@ -199,11 +204,11 @@ struct ZSTD_DCtx_s }; /* typedef'd to ZSTD_DCtx within "zstd.h" */ MEM_STATIC int ZSTD_DCtx_get_bmi2(const struct ZSTD_DCtx_s *dctx) { -#if DYNAMIC_BMI2 != 0 - return dctx->bmi2; +#if DYNAMIC_BMI2 + return dctx->bmi2; #else (void)dctx; - return 0; + return 0; #endif } diff --git a/lib/zstd/decompress_sources.h b/lib/zstd/decompress_sources.h index a06ca187aab5..8a47eb2a4514 100644 --- a/lib/zstd/decompress_sources.h +++ b/lib/zstd/decompress_sources.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/zstd_common_module.c b/lib/zstd/zstd_common_module.c index 22686e367e6f..466828e35752 100644 --- a/lib/zstd/zstd_common_module.c +++ b/lib/zstd/zstd_common_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -24,9 +24,6 @@ EXPORT_SYMBOL_GPL(HUF_readStats_wksp); EXPORT_SYMBOL_GPL(ZSTD_isError); EXPORT_SYMBOL_GPL(ZSTD_getErrorName); EXPORT_SYMBOL_GPL(ZSTD_getErrorCode); -EXPORT_SYMBOL_GPL(ZSTD_customMalloc); -EXPORT_SYMBOL_GPL(ZSTD_customCalloc); -EXPORT_SYMBOL_GPL(ZSTD_customFree); MODULE_LICENSE("Dual BSD/GPL"); MODULE_DESCRIPTION("Zstd Common"); diff --git a/lib/zstd/zstd_compress_module.c b/lib/zstd/zstd_compress_module.c index bd8784449b31..7651b53551c8 100644 --- a/lib/zstd/zstd_compress_module.c +++ b/lib/zstd/zstd_compress_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -16,6 +16,7 @@ #include "common/zstd_deps.h" #include "common/zstd_internal.h" +#include "compress/zstd_compress_internal.h" #define ZSTD_FORWARD_IF_ERR(ret) \ do { \ @@ -92,12 +93,64 @@ zstd_compression_parameters zstd_get_cparams(int level, } EXPORT_SYMBOL(zstd_get_cparams); +size_t zstd_cctx_set_param(zstd_cctx *cctx, ZSTD_cParameter param, int value) +{ + return ZSTD_CCtx_setParameter(cctx, param, value); +} +EXPORT_SYMBOL(zstd_cctx_set_param); + size_t zstd_cctx_workspace_bound(const zstd_compression_parameters *cparams) { return ZSTD_estimateCCtxSize_usingCParams(*cparams); } EXPORT_SYMBOL(zstd_cctx_workspace_bound); +// Used by zstd_cctx_workspace_bound_with_ext_seq_prod() +static size_t dummy_external_sequence_producer( + void *sequenceProducerState, + ZSTD_Sequence *outSeqs, size_t outSeqsCapacity, + const void *src, size_t srcSize, + const void *dict, size_t dictSize, + int compressionLevel, + size_t windowSize) +{ + (void)sequenceProducerState; + (void)outSeqs; (void)outSeqsCapacity; + (void)src; (void)srcSize; + (void)dict; (void)dictSize; + (void)compressionLevel; + (void)windowSize; + return ZSTD_SEQUENCE_PRODUCER_ERROR; +} + +static void init_cctx_params_from_compress_params( + ZSTD_CCtx_params *cctx_params, + const zstd_compression_parameters *compress_params) +{ + ZSTD_parameters zstd_params; + memset(&zstd_params, 0, sizeof(zstd_params)); + zstd_params.cParams = *compress_params; + ZSTD_CCtxParams_init_advanced(cctx_params, zstd_params); +} + +size_t zstd_cctx_workspace_bound_with_ext_seq_prod(const zstd_compression_parameters *compress_params) +{ + ZSTD_CCtx_params cctx_params; + init_cctx_params_from_compress_params(&cctx_params, compress_params); + ZSTD_CCtxParams_registerSequenceProducer(&cctx_params, NULL, dummy_external_sequence_producer); + return ZSTD_estimateCCtxSize_usingCCtxParams(&cctx_params); +} +EXPORT_SYMBOL(zstd_cctx_workspace_bound_with_ext_seq_prod); + +size_t zstd_cstream_workspace_bound_with_ext_seq_prod(const zstd_compression_parameters *compress_params) +{ + ZSTD_CCtx_params cctx_params; + init_cctx_params_from_compress_params(&cctx_params, compress_params); + ZSTD_CCtxParams_registerSequenceProducer(&cctx_params, NULL, dummy_external_sequence_producer); + return ZSTD_estimateCStreamSize_usingCCtxParams(&cctx_params); +} +EXPORT_SYMBOL(zstd_cstream_workspace_bound_with_ext_seq_prod); + zstd_cctx *zstd_init_cctx(void *workspace, size_t workspace_size) { if (workspace == NULL) @@ -209,5 +262,25 @@ size_t zstd_end_stream(zstd_cstream *cstream, zstd_out_buffer *output) } EXPORT_SYMBOL(zstd_end_stream); +void zstd_register_sequence_producer( + zstd_cctx *cctx, + void* sequence_producer_state, + zstd_sequence_producer_f sequence_producer +) { + ZSTD_registerSequenceProducer(cctx, sequence_producer_state, sequence_producer); +} +EXPORT_SYMBOL(zstd_register_sequence_producer); + +size_t zstd_compress_sequences_and_literals(zstd_cctx *cctx, void* dst, size_t dst_capacity, + const zstd_sequence *in_seqs, size_t in_seqs_size, + const void* literals, size_t lit_size, size_t lit_capacity, + size_t decompressed_size) +{ + return ZSTD_compressSequencesAndLiterals(cctx, dst, dst_capacity, in_seqs, + in_seqs_size, literals, lit_size, + lit_capacity, decompressed_size); +} +EXPORT_SYMBOL(zstd_compress_sequences_and_literals); + MODULE_LICENSE("Dual BSD/GPL"); MODULE_DESCRIPTION("Zstd Compressor"); diff --git a/lib/zstd/zstd_decompress_module.c b/lib/zstd/zstd_decompress_module.c index 469fc3059be0..0ae819f0c927 100644 --- a/lib/zstd/zstd_decompress_module.c +++ b/lib/zstd/zstd_decompress_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -113,7 +113,7 @@ EXPORT_SYMBOL(zstd_init_dstream); size_t zstd_reset_dstream(zstd_dstream *dstream) { - return ZSTD_resetDStream(dstream); + return ZSTD_DCtx_reset(dstream, ZSTD_reset_session_only); } EXPORT_SYMBOL(zstd_reset_dstream); -- 2.49.0.391.g4bbb303af6