Ryan Rueger

ryan@rueg.re / picture / key / home
aboutsummaryrefslogtreecommitdiffhomepage
path: root/lcsidh.py
blob: 9e2d91e21232fad35288dfa2ad460950178ae383 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#!/usr/bin/env python3

from sage.all import proof
from sage.arith.misc import gcd, valuation

from coin import sos_coins
from ideals import ideal_to_sage, short_ideals
from utilities import remove_common_factors

import logging

# Module-level logger writing through its own stream handler.
# NOTE: both the logger and the handler use the WARNING threshold, so the
# logger.info calls in this module stay silent unless the level is lowered.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logger_sh = logging.StreamHandler()
logger_sh.setLevel(logging.WARNING)
# Output format: "<logger name> [<LEVEL>] <message>"
formatter = logging.Formatter("%(name)s [%(levelname)s] %(message)s")
logger_sh.setFormatter(formatter)
logger.addHandler(logger_sh)

# Switch off Sage's global proof flag for all subsystems.
# NOTE(review): presumably done to allow faster, non-proven algorithms -- confirm
proof.all(False)


def ascending_combinations(max_norm):
    """Sequence of distinct non-negative integer pairs in ascending L1-norm up to permutation

    The pairs returned (x, y) satisfy x > y >= 0, so every unordered pair of
    distinct non-negative integers appears exactly once, larger entry first.
    Pairs are ordered by L1-norm x + y and, within one norm, by increasing y.

    Arguments:
      - max_norm: Return pairs with L1 norm strictly less than max_norm

    Note: We return pairs with L1-norm up to (but not including) max_norm for compatibility with previous
    implementation.

    Example: max_norm = 5 yields

        (1, 0), (2, 0), (3, 0), (2, 1), (4, 0), (3, 1)
    """

    # Norm 0 only admits the pair (0, 0), which is not a pair of distinct
    # entries, so the enumeration starts at norm 1
    norm = 1
    while norm < max_norm:
        # Pair `norm - small, small` has L1-norm `norm`
        # We want pairs up to permutation, so we need `norm - small > small`,
        # which holds exactly for `small < (norm + 1) // 2`
        for small in range((norm + 1) // 2):
            yield norm - small, small
        norm += 1


def two_signature(ideal, twoval, uv_params):
    """Return signature of norm-2-ideals contained in ideal

    Input:
      - ideal: Ideal as list of generators, where a generator is
               a pair of numbers [a, b] representing the element a + ib
      - twoval: The 2-valuation of the norm of the input ideal
      - uv_params: Instance of uv_params class for ambient order information;
                   the attributes max_order, two_left and two_right are read

    Output: Tuple of integers

        (left - min(left, right), right - min(left, right))

    Where two_left**left divides ideal and two_right**right divides ideal

    Note: One entry of the output tuple is zero.
    """

    sage_ideal = ideal_to_sage(ideal, uv_params.max_order)

    def exponent_of(ideal_above_two):
        # The sum with the twoval-th power is the part the two ideals have in
        # common; the 2-valuation of its norm is the power that was really shared
        common_part = sage_ideal + ideal_above_two**twoval
        return valuation(common_part.norm(), 2)

    left = exponent_of(uv_params.two_left)
    right = exponent_of(uv_params.two_right)

    # Whatever both sides share corresponds to multiplication-by-two, which is
    # assumed to be factored out later on, so it is dropped from the signature
    shared = min(left, right)

    return left - shared, right - shared


def UVSolver(ideal, ideal_norm, uv_params):
    """Solve UV equation for given ideal

    Given an ideal ideal try to find two equivalent ideals I_1 and I_2 such that,
    after writing n(I_1) = N1*cof1 and n(I_2) = N2*cof2 we can solve the
    equation
        u*N1 + v*N2 = 2^e
    for u and v sums of two squares.

    Input:
     - ideal:      the input ideal as sage ideal
     - ideal_norm: the (ideal) norm of ideal;
     - uv_params: an instance of uv_params.UV_params containing the parameters
                  for the algorithm, as described in the UV_params class.

                  Note: e from the 2^e in the norm equation above is given by
                  uv_params.uv_e

                  Note: a falsy uv_params.sol_bound (None or 0) means that the
                  number of solutions examined is not bounded.

    Output:
      - u, v, I_1, I_2, sig_1, sig_2, t

      or None if no solution was found,

    where
      - u, v are written as
            (x_u, y_u, g_u)
        such that g_u*(x_u^2 + y_u^2)=u

        Note: if uv_params.n_squares=1, only one among u and v is decomposed as
        sum of squares.

      - I_1, I_2 are ideals represented as
            (norm, cofactor, 2-valuation, gens)
        where
            norm * cofactor * 2**(2-valuation)
        is the norm of the ideal generated by gens;

      - setting
            n1 = I_1[0], n2 = I_2[0] (the cofactor-free norms)
        it holds
            2**t == u*n1 + v*n2
        in general t <= uv_params.uv_e, but they do not need to be equal;

      - sig_1, sig_2 are the ideals over 2 that are left in I_1, I_2
        respectively after multiplications-by-two are removed, as a pair (left,
        right).

        One of `left` and `right` must be zero; if they point in the same
        direction (e. g. (e_1, 0) and (e_2, 0)) the algorithm assumes that the
        common part is computed in advance and removed by the ideals.

    If uv_params.sol_bound == 1 the first solution found is returned
    immediately. Otherwise the solution with the lowest estimated cost of
    Elkies steps among the examined ones is returned (the earliest one on ties).
    """
    # Look for short vectors
    # Elements of L are of the form (norm, [elements])
    p, spr = uv_params.p, uv_params.spr
    norm_bound = uv_params.norm_bound
    L = short_ideals(ideal, ideal_norm, norm_bound * spr, p, spr, uv_params.norm_cff_bound)
    logger.info(f"{len(L) = }")

    # Remove allowed primes
    # Maps cofactor-free norm -> (cofactor, 2-valuation, gens)
    ideals = {}
    for nI, I in L:
        if gcd(nI, ideal_norm) > 1:
            continue
        # Split off the power of two
        v2 = nI.valuation(2)
        nI //= 2**v2

        # NOTE(review): the first value is used as the cofactor-free norm and
        # the second one as the cofactor -- confirm against remove_common_factors
        smooth_part, rough_cofactor = remove_common_factors(nI, uv_params.allowed_primes_prod)

        if smooth_part in ideals:
            known_cofactor, known_v2, _ = ideals[smooth_part]
            if rough_cofactor * 2**v2 >= known_cofactor * 2**known_v2:
                # Multiple of a previous vector
                continue
        ideals[smooth_part] = (rough_cofactor, v2, I)

    # Short vectors sorted by cofactor-free norm
    # Entries are tuples (norm, cofactor, twoval, ideal)
    ideals = sorted((norm,) + entry for norm, entry in ideals.items())

    cached_two_signatures = {}

    def signature_of(index):
        # Compute the part above two of ideals[index], at most once per index
        if index not in cached_two_signatures:
            _, _, twoval, gens = ideals[index]
            cached_two_signatures[index] = two_signature(gens, twoval, uv_params)
        return cached_two_signatures[index]

    # A falsy bound (None or 0) means "unbounded"
    sol_bound = uv_params.sol_bound

    # Keep track of the best solution found so far
    best_sol = None
    best_cost = None
    sol_count = 0

    # Try pairs to solve the coin equation
    for i, j in ascending_combinations(min(len(ideals), uv_params.comb_bound)):

        if sol_bound and sol_count >= sol_bound:
            break

        norm_1 = ideals[i][0]
        norm_2 = ideals[j][0]

        if gcd(norm_1, norm_2) != 1:
            continue

        signature_1 = signature_of(i)
        signature_2 = signature_of(j)

        # Compute the remaining torsion
        used_torsion = abs(signature_1[0] - signature_2[0]) + abs(signature_1[1] - signature_2[1])
        remaining_torsion = uv_params.uv_e - used_torsion

        # Avoid generating too many combinations for a single pair
        local_sol_count = 0

        for (_, x_u, y_u, g_u), (_, x_v, y_v, g_v), t in sos_coins(
            norm_1, norm_2, remaining_torsion, uv_params.sos_allowed_primes, n_squares=uv_params.n_squares
        ):

            assert g_u * (x_u ** 2 + y_u ** 2) * norm_1 + g_v * (x_v ** 2 + y_v ** 2) * norm_2 == 2**t

            sol_count += 1
            local_sol_count += 1

            solution = (x_u, y_u, g_u), (x_v, y_v, g_v), ideals[i], ideals[j], signature_1, signature_2, t

            if sol_bound == 1:
                return solution

            # Cost estimate for best solution
            # Recall: ideals[i] is a tuple (norm, cofactor, twoval, ideal)
            # The cofactor is smooth, and contains the product of norms of elkies steps
            elkies_steps = (ideals[i][1], ideals[j][1], g_u, g_v)

            # Recall: Elkies steps of norm prime are done for all primes in uv_params.allowed_primes
            # We assume cost to be linear in the size of the prime
            # We count the number of steps required by getting the sum of valuations for each ideal
            elkies_cost = sum(
                prime * sum(valuation(value, prime) for value in elkies_steps)
                for prime in uv_params.allowed_primes
            )

            # Compare against None explicitly: a cost of 0 is a legitimate (and
            # optimal) value and must not be mistaken for "nothing found yet"
            if best_cost is None or elkies_cost < best_cost:
                best_sol = solution
                best_cost = elkies_cost

            # Same guard as at the top of the outer loop, so that an unset
            # bound neither raises nor cuts the search short
            if sol_bound and sol_count >= sol_bound:
                break

            # @TODO: this is completely arbitrary
            if local_sol_count > 3:
                break

    return best_sol