Skip to content
Empty file.
56 changes: 56 additions & 0 deletions numerical/2019_02_pythagorean_triples/more_triples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

# ref:
# https://en.wikipedia.org/wiki/Formulas_for_generating_Pythagorean_triples

from itertools import islice
from heapq import merge

from numba import njit


def take(n, iterable):
return list(islice(iterable, n))


@njit
def py_triples_stifel():
n = 1

while True:
denom = n * 2 + 1
improper_numerator = n * (denom + 1)

yield denom, improper_numerator, improper_numerator + 1

n += 1


@njit
def py_triples_ozanam():
n = 1

while True:
denom = 4 * (n + 1)
improper_numerator = denom * (1 + n) - 1

yield denom, improper_numerator, improper_numerator + 2

n += 1


def py_triples_stifel_ozanam():
# all primitive triples of the Plato and Pythagoras families
return merge(py_triples_stifel(), py_triples_ozanam())


@njit
def py_triples_fibonacci():
k = 3

while True:
c_2 = k ** 2
n = (c_2 + 1) >> 1

yield k, (n - 1), n

k += 2
33 changes: 33 additions & 0 deletions numerical/2019_02_pythagorean_triples/test_more_triples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pytest import mark

from more_triples import (
py_triples_stifel, py_triples_ozanam,
py_triples_fibonacci, take,
py_triples_stifel_ozanam
)

ALL_FNS = [
py_triples_stifel, py_triples_ozanam,
py_triples_fibonacci
]


def check(gen, first_n=100):
for a, b, c in take(first_n, gen):
assert a**2 + b**2 == c**2


@mark.parametrize('fn', ALL_FNS, ids=lambda x: f'{x.__name__}')
def test_py_triple(fn):
gen = fn()
check(gen)


def test_py_triples_fused_stifel_ozanam():
gen = py_triples_stifel_ozanam()
check(gen)


if __name__ == '__main__':
import pytest
pytest.main([__file__])
1 change: 0 additions & 1 deletion numerical/2019_02_pythagorean_triples/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def fibonacci_triples():
odd_ints.append(odd)

if int(odd**0.5)**2 == int(odd) and len(odd_ints) >= 4:
n = (odd + 1) / 2
b2 = sum(odd_ints[:-1])
c2 = sum(odd_ints)
yield tuple(map(int, (sqrt(odd), sqrt(b2), sqrt(c2))))
Expand Down