Ryan Rueger

ryan@rueg.re / picture / key / home
aboutsummaryrefslogtreecommitdiffhomepage
path: root/lcsidh.py
blob: 30ff6ce202a396d97deef012879f668e227f6bda (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from sage.all import(
    isqrt, sqrt, floor, prod,
    randint,
    log,
    proof,
    next_prime,
    kronecker,
    GF, ZZ, QQ,
    Matrix,
    gcd, xgcd, valuation,
)

from coin import sos_pair_generator, xsos_pair_generator
from ideals import ideal_to_sage, ShortIdeals
import time

import logging
# Module-level logger. Its level is WARNING, so logger.info(...) calls in this
# module are dropped unless a caller lowers the level of this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
# Dedicated handler (logging.StreamHandler writes to stderr by default) that
# prefixes each record with the logger name and level.
logger_sh = logging.StreamHandler()
logger_sh.setLevel(logging.WARNING)
formatter = logging.Formatter('%(name)s [%(levelname)s] %(message)s')
logger_sh.setFormatter(formatter)
# NOTE(review): attaching a handler at import time means records still
# propagate to the root logger as well; if the application also configures
# root logging, warnings may be printed twice. Consider leaving handler setup
# to the application.
logger.addHandler(logger_sh)

# Turn off Sage's global "proof" flag for the whole process. Presumably this
# lets Sage use faster, non-proven algorithms (e.g. pseudo-primality instead
# of proven primality) -- confirm against Sage's `proof` documentation.
proof.all(False)


def generate_combs(n):
    """
    Yield index pairs (i, j) with i > j >= 0, ordered by increasing i + j.

    Only pairs whose L1-norm i + j is strictly smaller than n are produced
    (so both indices are below n as well). Within one L1-norm the pairs come
    with j increasing, i.e. (total, 0), (total - 1, 1), ...
    """
    total = 1
    while total < n:
        # i = total - j must remain strictly larger than j, i.e. 2*j < total.
        for j in range((total + 1) // 2):
            yield total - j, j
        total += 1


def eval_sol(vals, allowed_primes):
    """
    Estimate the cost of an isogeny built from Elkies steps.

    Input:
    - vals: values whose factors among allowed_primes give the degrees of the
      Elkies steps;
    - allowed_primes: the primes admissible as Elkies degrees.

    Output:
    - sum over ell in allowed_primes of ell * (total valuation at ell of all
      entries of vals); i.e. the cost of one step of degree ell is modelled as
      linear in ell.
    """
    # Total exponent of each allowed prime across all the values.
    exponent_of = {
        ell: sum(valuation(val, ell) for val in vals)
        for ell in allowed_primes
    }
    return sum(ell * exponent_of[ell] for ell in allowed_primes)


def UVSolver(R, N_R, uv_params):
    """
    Given an ideal R try to find two equivalent ideals I_1 and I_2 such that,
    after writing n(I_1) = N1*cof1 and n(I_2) = N2*cof2 we can solve the
    equation
        u*N1 + v*N2 = 2^e
    for u and v sums of two squares.

    Input:
    - R: the input ideal as a list of generators, where a generator is a pair
        of numbers [a, b] representing the element a + ib;
    - N_R: the (ideal) norm of R;
    - uv_params: an instance of uv_params.UV_params containing the parameters
      for the algorithm, as described in the UV_params class. A falsy
      uv_params.sol_bound (None or 0) means the total number of examined
      solutions is unbounded; sol_bound == 1 returns the first solution found
      without comparing costs.

    Output:
    - u, v, ids[i], ids[j], sig_1, sig_2, t, or None if no pair of ideals
      yields a solution,
    where
    - u, v are written as (x_u, y_u, g_u) such that g_u*(x_u^2 + y_u^2)=u
    - ids[i], ids[j] are ideals represented as (norm, cofactor, v2, gens),
      where norm is coprime to uv_params.allowed_primes_prod, v2 is the
      2-valuation, and norm * cofactor * 2**v2 is the norm of the ideal
      generated by gens;
    - setting n1=ids[i][0], n2 = ids[j][0] (the cofactor-free norms, with
      n1 > n2) it holds 2**t == u*n1 + v*n2; t is expected not to exceed the
      budget uv_params.uv_e minus the 2-torsion used by sig_1/sig_2 (the value
      handed to xsos_pair_generator), but the two need not be equal;
    - sig_1, sig_2 are the ideals over 2 that are left in ids[i], ids[j]
      respectively after multiplications by two are removed, as a pair (left,
      right); one between left and right must be zero; if they point in the
      same direction (e. g. (e_1, 0) and (e_2, 0)) the algorithm assumes that
      the common part is computed in advance and removed by the ideals.
    - if uv_params.n_squares=1, only one among u and v is decomposed as sum of
      squares
    """

    # Look for short vectors
    # Elements of L are of the form (norm, [elements])
    p, spr = uv_params.p, uv_params.spr
    norm_bound = uv_params.norm_bound
    L = ShortIdeals(R, N_R, norm_bound*spr, p, spr, uv_params.norm_cff_bound)
    logger.info(f'{len(L) = }')

    sol_bound = uv_params.sol_bound
    allowed_prod = uv_params.allowed_primes_prod

    # Split every norm as k * cof * 2**v2 with k coprime to all allowed
    # primes; for each k keep the first ideal with the smallest cof * 2**v2.
    ids = {}
    for nI, I in L:
        # Strip the power of two; it is tracked separately as v2
        v2 = nI.valuation(2)
        nI //= 2**v2

        # NOTE(review): `/` presumably yields Sage rationals (denominator 1)
        # here; kept unchanged because downstream code receives these values.
        k = nI / gcd(nI, allowed_prod)
        while gcd(k, allowed_prod) != 1:
            k /= gcd(k, allowed_prod)

        cof = nI/k
        if k in ids and cof*2**v2 >= ids[k][0]*2**ids[k][1]:
            # Multiple of a previous vector
            continue
        ids[k] = (cof, v2, I)

    # Short vectors sorted by cofactor-free norm (keys are unique, so the
    # ordering is decided by the norm alone)
    ids = [(k,) + rest for k, rest in ids.items()]
    ids.sort()

    two_parts = {}
    two_left = ideal_to_sage([[2, 0], [-1/2, 1/2]], uv_params.max_order)
    two_right = ideal_to_sage([[2, 0], [1/2, 1/2]], uv_params.max_order)

    # Keep track of the best solution found so far
    best_sol = None
    best_cost = None
    sol_cnt = 0
    # Maximum number of solutions examined for a single pair of ideals.
    # TODO: this is completely arbitrary
    max_local_sols = 4

    def compute_two_sig(id_i, v2_i):
        # (left, right) 2-valuations of id_i + two_{left,right}**v2_i, with
        # the common part (plain multiplications by 2) removed, so at least
        # one of the two entries is zero.
        id_sage = ideal_to_sage(id_i, uv_params.max_order)
        id_l = id_sage + two_left**v2_i
        id_r = id_sage + two_right**v2_i
        v2_l = valuation(id_l.norm(), 2)
        v2_r = valuation(id_r.norm(), 2)
        mul2 = min(v2_l, v2_r)
        sig_i = (v2_l - mul2, v2_r - mul2)
        return sig_i

    def two_sig(idx):
        # compute_two_sig for ids[idx], memoized by index
        if idx not in two_parts:
            _, _, v2_idx, id_idx = ids[idx]
            two_parts[idx] = compute_two_sig(id_idx, v2_idx)
        return two_parts[idx]

    # Try pairs to solve the coin equation
    for i, j in generate_combs(min(len(ids), uv_params.comb_bound)):

        if sol_bound and sol_cnt >= sol_bound:
            break
        n1, cof1, v2_1, id1 = ids[i]
        n2, cof2, v2_2, id2 = ids[j]
        if gcd(n1, n2) != 1:
            continue

        # Compute the part above two of the ideals
        sig_1 = two_sig(i)
        sig_2 = two_sig(j)

        # Compute the remaining torsion
        used_torsion = abs(sig_1[0] - sig_2[0]) + abs(sig_1[1] - sig_2[1])
        av_torsion = uv_params.uv_e - used_torsion

        # Avoid generating too many combinations for a single pair
        local_sol_cnt = 0
        for u, v, t in xsos_pair_generator(n1, n2, av_torsion,
                                           uv_params.sos_allowed_primes,
                                           n_squares=uv_params.n_squares):
            sol_cnt += 1
            local_sol_cnt += 1

            if sol_bound == 1:
                return u, v, ids[i], ids[j], sig_1, sig_2, t

            # Cost estimate: Elkies steps for both cofactors, plus the g part
            # of u and v when they are given as (x, y, g).
            elkies_vals = [cof1, cof2]
            for w in (u, v):
                if isinstance(w, (tuple, list)):
                    elkies_vals.append(w[2])
            elkies_cost = eval_sol(elkies_vals, uv_params.allowed_primes)
            if best_sol is None or elkies_cost < best_cost:
                best_sol = u, v, ids[i], ids[j], sig_1, sig_2, t
                best_cost = elkies_cost
            # A falsy sol_bound means "no global bound", as in the check at
            # the top of the outer loop (comparing against None would raise).
            if sol_bound and sol_cnt >= sol_bound:
                break
            if local_sol_cnt >= max_local_sols:
                break

    return best_sol