An elementary example of calculating private and public keys for elliptic curve cryptography (ECC). Part of a network security class.
A very simple example to show some concepts of elliptic curve cryptography. It demonstrates:
import math
from random import randint
import numpy as np
import matplotlib.pyplot as plt
Accepts the coefficients and the modulus of an elliptic curve as input.
This could have been done better.
class Curve:
    """Parameters of the elliptic curve y^2 = x^3 + ax + b (mod p).

    The values are stored as given; no validation is performed.

    Attributes:
        a: coefficient of the x term.
        b: constant term.
        p: modulus of the curve equation.
    """

    def __init__(self, a, b, p):
        self.a, self.b, self.p = a, b, p
Accepts the coordinates of an elliptic curve point as input,
e.g. a point P has coordinates (x, y).
class Point:
    """A point with coordinates (x, y).

    A point counts as the point at infinity when at least one of its
    coordinates is infinite (see is_infinity).
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    # Two points are equal when both their x and their y values match.
    def is_equal(self, p):
        if self.x != p.x:
            return False
        return self.y == p.y

    # True when either coordinate is infinite.
    def is_infinity(self):
        return any(math.isinf(coord) for coord in (self.x, self.y))
Generates a private key
Accepts a Point p and a Curve c
Returns a pseudo-random private key between 2 and p-2, where p is the modulus of the curve
def getECCPrivateKey(p, c):
    """Return a random private key k with 2 <= k <= c.p - 2.

    Args:
        p: Point the key is meant to be used with. It is not needed to draw
            the key and is only kept so existing callers keep working.
        c: Curve whose modulus ``c.p`` bounds the key range.

    Returns:
        An int drawn uniformly from the inclusive range [2, c.p - 2].

    Raises:
        ValueError: if that range is empty (c.p < 4).
    """
    # Key material has to come from a cryptographically secure generator;
    # the default ``random`` generator is predictable and not meant for keys.
    import secrets

    # randbelow(n) yields 0 .. n - 1, so shift it onto 2 .. c.p - 2.
    return 2 + secrets.randbelow(c.p - 3)
Performs point doubling
Accepts a Curve c and two Point objects p1, p2
This function performs doubling of p1. Note that p1 and p2 should be the same point.
Returns a point p3, which is the result of doubling p1.
def pointDoubling(c, p1, p2):
    """Double a point on the curve, i.e. return p1 + p1.

    Args:
        c: Curve providing the coefficient ``a`` and the modulus ``p``.
        p1: Point to double.
        p2: Expected to be the same point as p1 (kept for the existing call
            signature).

    Returns:
        A new Point holding 2 * p1. The point at infinity is returned when
        either input is at infinity, when 2 * y == 0 (mod p) so the tangent
        is vertical, or when p1 and p2 differ (general point addition is not
        implemented by this function).

    Raises:
        ValueError: raised by modinv when 2 * y is non-zero but has no
            inverse modulo c.p.
    """
    infinity = Point(math.inf, math.inf)
    if p1.is_infinity() or p2.is_infinity():
        return infinity
    if not p1.is_equal(p2):
        return infinity

    denominator = (2 * p1.y) % c.p
    if denominator == 0:
        # Vertical tangent: the double is the point at infinity. Without
        # this guard modinv(0, p) would raise instead.
        return infinity

    # Slope of the tangent at p1. It is an integer reduced mod p, so it can
    # never be infinite and needs no isinf check.
    m = ((3 * p1.x**2 + c.a) * modinv(denominator, c.p)) % c.p
    x3 = (m**2 - p1.x - p2.x) % c.p
    y3 = (m * (p1.x - x3) - p1.y) % c.p
    return Point(x3, y3)
Create a public key from an elliptic curve
Accepts a Curve c, a Point object p and a private key k
Returns the public key for the curve
def _pointAdd(c, p1, p2):
    """Return p1 + p2 on curve c; handles doubling and the point at infinity."""
    if p1.is_infinity():
        return p2
    if p2.is_infinity():
        return p1

    if (p1.x - p2.x) % c.p == 0:
        if (p1.y + p2.y) % c.p == 0:
            # P + (-P); this also covers doubling a point whose y is 0.
            return Point(math.inf, math.inf)
        # Same point: slope of the tangent.
        slope = (3 * p1.x**2 + c.a) * modinv((2 * p1.y) % c.p, c.p)
    else:
        # Distinct points: slope of the chord. The denominator is reduced
        # first so modinv never receives a negative value.
        slope = (p2.y - p1.y) * modinv((p2.x - p1.x) % c.p, c.p)

    slope %= c.p
    x3 = (slope**2 - p1.x - p2.x) % c.p
    y3 = (slope * (p1.x - x3) - p1.y) % c.p
    return Point(x3, y3)


def getECCPublicKey(c, p, k):
    """Return the public key k * p (scalar multiplication of p by k).

    The previous implementation doubled the point k - 1 times, which yields
    2**(k - 1) * p rather than k * p and needs a number of steps
    proportional to k. Double-and-add computes k * p in about log2(k) steps.

    Args:
        c: Curve the point lies on.
        p: Base Point.
        k: Private key (integer scalar). 0 gives the point at infinity; a
            negative k is treated as |k| times the negated point.

    Returns:
        A Point equal to k * p. For k == 1 this is the ``p`` object itself.

    Raises:
        ValueError: raised by modinv if a required inverse does not exist.
    """
    if k < 0:
        # k * P == |k| * (-P), and -P mirrors the y coordinate.
        if not p.is_infinity():
            p = Point(p.x, -p.y % c.p)
        k = -k

    pubK = Point(math.inf, math.inf)  # running sum, starts at the identity
    addend = p  # holds 2**i * p for the current bit i
    while k:
        if k & 1:
            pubK = _pointAdd(c, pubK, addend)
        k >>= 1
        if k:
            addend = _pointAdd(c, addend, addend)
    return pubK
Implementation of the extended Euclidean algorithm
def egcd(a, p):
    """Extended Euclidean algorithm.

    Args:
        a: first integer.
        p: second integer.

    Returns:
        A tuple (g, x, y) with a * x + p * y == g, where g is the
        non-negative greatest common divisor of a and p.

    The computation is iterative, so operands of any size work without
    hitting the interpreter's recursion limit.
    """
    # Invariants: low  == low_x  * a + low_y  * p
    #             high == high_x * a + high_y * p
    low, high = a, p
    low_x, low_y = 1, 0
    high_x, high_y = 0, 1
    while low != 0:
        q = high // low
        low, high = high - q * low, low
        low_x, high_x = high_x - q * low_x, low_x
        low_y, high_y = high_y - q * low_y, low_y

    if high < 0:
        # Negative inputs can leave a negative gcd; flip all signs so the
        # identity still holds and callers can compare g against 1.
        high, high_x, high_y = -high, -high_x, -high_y
    return (high, high_x, high_y)
Find the multiplicative inverse of a (mod m)
def modinv(a, m):
    """Return the multiplicative inverse of a modulo m.

    Args:
        a: value to invert; may be negative or larger than m.
        m: modulus, must be non-zero.

    Returns:
        x such that (a * x) % m == 1 % m, reduced with ``% m``.

    Raises:
        ValueError: if m is zero or a has no inverse modulo m.
    """
    if m == 0:
        raise ValueError("modulus must be non-zero")

    # Reduce first: a negative a could otherwise produce a gcd of -1 and be
    # rejected even though an inverse exists.
    g, x, _ = egcd(a % m, m)
    if g < 0:
        g, x = -g, -x
    if g != 1:
        raise ValueError("{0} has no inverse modulo {1}".format(a, m))
    return x % m
Sample data
# Sample curve C: y^2 = x^3 + x + 1 (mod 23)
c = Curve(a=1, b=1, p=23)
# Sample point P = (9, 7)
p = Point(x=9, y=7)

# Draw a random private key and show it
k = getECCPrivateKey(p, c)
print("Private Key = {0}".format(k))

# Derive the public key from the curve, the point and the private key,
# then show its coordinates
pubK = getECCPublicKey(c, p, k)
print('Public key = ( x: {0}, y:{1} )'.format(pubK.x, pubK.y))
# Plot y^2 = x^3 + x + 1 over the real numbers on [-10, 10] x [-10, 10]
fig, ax = plt.subplots()
xlist = np.linspace(-10, 10, 100)
ylist = np.linspace(-10, 10, 100)
X, Y = np.meshgrid(xlist, ylist)
# The curve is the zero level set of y^2 - x^3 - x - 1
ax.contour(X, Y, Y**2 - X**3 - X * 1 - 1, [0])
ax.grid()
# Draw both coordinate axes in black
ax.axhline(y=0, color="k")
ax.axvline(x=0, color="k")
plt.show()
# From here on ``p`` is the integer modulus (it was the sample Point before)
p = 23
Calculate perfect squares mod p
# Collect the distinct non-zero squares modulo p, together with the smallest
# input in 1 .. p-1 that produces each of them
p = 23
perfect_squares_input = []
perfect_squares = []
for i in range(1, p):
    sq = pow(i, 2, p)
    if sq in perfect_squares:
        # this square was already produced by a smaller input
        continue
    perfect_squares_input.append(i)
    perfect_squares.append(sq)
print(perfect_squares_input)
print(perfect_squares)
Calculate points on the curve mod p
# Enumerate every point (x, y) with y^2 = x^3 + x + 1 (mod p).
# Relies on p, perfect_squares and perfect_squares_input computed earlier.
xpoints_on_curve = []
ypoints_on_curve = []
for i in range(p):
    # Reduce with p (not a hard-coded 23) so this matches the squares table
    y_sqr = (pow(i, 3) + i * 1 + 1) % p
    if y_sqr == 0:
        # 0 is not in the table of non-zero squares, but y = 0 is a valid
        # (single) root, so the point (i, 0) lies on the curve.
        xpoints_on_curve.append(i)
        ypoints_on_curve.append(0)
    elif y_sqr in perfect_squares:
        # Look the root up in the inputs table instead of assuming it
        # equals index + 1
        sqr_root = perfect_squares_input[perfect_squares.index(y_sqr)]
        # Each non-zero square has the two roots +root and -root (mod p)
        xpoints_on_curve.extend([i, i])
        ypoints_on_curve.extend([sqr_root % p, -sqr_root % p])
print(xpoints_on_curve)
print(ypoints_on_curve)
Note the horizontal symmetry of the points.
# Scatter the computed curve points, with a tick at every integer 0 .. 22 on
# both axes
plt.rcParams['figure.dpi'] = 100
fig2, ax2 = plt.subplots()
ax2.set_xticks(np.arange(0, 23, 1))
ax2.set_yticks(np.arange(0, 23, 1))
X2, Y2 = xpoints_on_curve, ypoints_on_curve
ax2.scatter(X2, Y2)
ax2.grid()