aboutsummaryrefslogtreecommitdiff
path: root/core/math/big/internal.odin
diff options
context:
space:
mode:
authorJeroen van Rijn <Kelimion@users.noreply.github.com>2021-08-07 16:52:04 +0200
committerJeroen van Rijn <Kelimion@users.noreply.github.com>2021-08-11 20:59:53 +0200
commit62dcccd7ef21ba4c0049f2051bf1cd15433fef8d (patch)
treecfca19d9a22a6984d1c3beefcebf753753346b93 /core/math/big/internal.odin
parente288a563e1cbe4432b457c6891bf86abfec27412 (diff)
big: Move division internals.
Diffstat (limited to 'core/math/big/internal.odin')
-rw-r--r--core/math/big/internal.odin357
1 files changed, 349 insertions, 8 deletions
diff --git a/core/math/big/internal.odin b/core/math/big/internal.odin
index 8505f6f31..a3e2548db 100644
--- a/core/math/big/internal.odin
+++ b/core/math/big/internal.odin
@@ -608,7 +608,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
/* Fast comba? */
// err = s_mp_sqr_comba(a, c);
} else {
- err = _int_sqr(dest, src);
+ err = _private_int_sqr(dest, src);
}
} else {
/*
@@ -680,14 +680,13 @@ internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, a
// err = _int_div_recursive(quotient, remainder, numerator, denominator);
} else {
when true {
- err = _int_div_school(quotient, remainder, numerator, denominator);
+ err = _private_int_div_school(quotient, remainder, numerator, denominator);
} else {
/*
- NOTE(Jeroen): We no longer need or use `_int_div_small`.
+ NOTE(Jeroen): We no longer need or use `_private_int_div_small`.
We'll keep it around for a bit until we're reasonably certain div_school is bug free.
- err = _int_div_small(quotient, remainder, numerator, denominator);
*/
- err = _int_div_small(quotient, remainder, numerator, denominator);
+ err = _private_int_div_small(quotient, remainder, numerator, denominator);
}
}
return;
@@ -744,7 +743,7 @@ internal_int_divmod_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT)
Three?
*/
if denominator == 3 {
- return _int_div_3(quotient, numerator);
+ return _private_int_div_3(quotient, numerator);
}
/*
@@ -1049,8 +1048,6 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
/*
Now extract the previous digit [below the carry].
*/
- // for ix = 0; ix < pa; ix += 1 { dest.digit[ix] = W[ix]; }
-
copy_slice(dest.digit[0:], W[:pa]);
/*
@@ -1065,6 +1062,350 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
return clamp(dest);
}
+/*
+ Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
+*/
+_private_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
+ pa := src.used;
+
+ t := &Int{}; ix, iy: int;
+ /*
+ Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
+ */
+ if err = grow(t, max((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != nil { return err; }
+ t.used = (2 * pa) + 1;
+
+ #no_bounds_check for ix = 0; ix < pa; ix += 1 {
+ carry := DIGIT(0);
+ /*
+ First calculate the digit at 2*ix; calculate double precision result.
+ */
+ r := _WORD(t.digit[ix+ix]) + (_WORD(src.digit[ix]) * _WORD(src.digit[ix]));
+
+ /*
+ Store lower part in result.
+ */
+ t.digit[ix+ix] = DIGIT(r & _WORD(_MASK));
+ /*
+ Get the carry.
+ */
+ carry = DIGIT(r >> _DIGIT_BITS);
+
+ #no_bounds_check for iy = ix + 1; iy < pa; iy += 1 {
+ /*
+ First calculate the product.
+ */
+ r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]);
+
+ /* Now calculate the double precision result. NĂ³te we use
+ * addition instead of *2 since it's easier to optimize
+ */
+ r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry);
+
+ /*
+ Store lower part.
+ */
+ t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+
+ /*
+ Get carry.
+ */
+ carry = DIGIT(r >> _DIGIT_BITS);
+ }
+ /*
+ Propagate upwards.
+ */
+ #no_bounds_check for carry != 0 {
+ r = _WORD(t.digit[ix+iy]) + _WORD(carry);
+ t.digit[ix+iy] = DIGIT(r & _WORD(_MASK));
+ carry = DIGIT(r >> _WORD(_DIGIT_BITS));
+ iy += 1;
+ }
+ }
+
+ err = clamp(t);
+ swap(dest, t);
+ destroy(t);
+ return err;
+}
+
+/*
+ Divide by three (based on routine from MPI and the GMP manual).
+*/
+_private_int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: DIGIT, err: Error) {
+ /*
+ b = 2^_DIGIT_BITS / 3
+ */
+ b := _WORD(1) << _WORD(_DIGIT_BITS) / _WORD(3);
+
+ q := &Int{};
+ if err = grow(q, numerator.used); err != nil { return 0, err; }
+ q.used = numerator.used;
+ q.sign = numerator.sign;
+
+ w, t: _WORD;
+ #no_bounds_check for ix := numerator.used; ix >= 0; ix -= 1 {
+ w = (w << _WORD(_DIGIT_BITS)) | _WORD(numerator.digit[ix]);
+ if w >= 3 {
+ /*
+ Multiply w by [1/3].
+ */
+ t = (w * b) >> _WORD(_DIGIT_BITS);
+
+ /*
+ Now subtract 3 * [w/3] from w, to get the remainder.
+ */
+ w -= t+t+t;
+
+ /*
+ Fixup the remainder as required since the optimization is not exact.
+ */
+ for w >= 3 {
+ t += 1;
+ w -= 3;
+ }
+ } else {
+ t = 0;
+ }
+ q.digit[ix] = DIGIT(t);
+ }
+ remainder = DIGIT(w);
+
+ /*
+ [optional] store the quotient.
+ */
+ if quotient != nil {
+ err = clamp(q);
+ swap(q, quotient);
+ }
+ destroy(q);
+ return remainder, nil;
+}
+
+/*
+ Signed Integer Division
+
+ c*b + d == a [i.e. a/b, c=quotient, d=remainder], HAC pp.598 Algorithm 14.20
+
+ Note that the description in HAC is horribly incomplete.
+ For example, it doesn't consider the case where digits are removed from 'x' in
+ the inner loop.
+
+ It also doesn't consider the case that y has fewer than three digits, etc.
+ The overall algorithm is as described as 14.20 from HAC but fixed to treat these cases.
+*/
+_private_int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
+ // if err = error_if_immutable(quotient, remainder); err != nil { return err; }
+ // if err = clear_if_uninitialized(quotient, numerator, denominator); err != nil { return err; }
+
+ q, x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
+ defer destroy(q, x, y, t1, t2);
+
+ if err = grow(q, numerator.used + 2); err != nil { return err; }
+ q.used = numerator.used + 2;
+
+ if err = init_multi(t1, t2); err != nil { return err; }
+ if err = copy(x, numerator); err != nil { return err; }
+ if err = copy(y, denominator); err != nil { return err; }
+
+ /*
+ Fix the sign.
+ */
+ neg := numerator.sign != denominator.sign;
+ x.sign = .Zero_or_Positive;
+ y.sign = .Zero_or_Positive;
+
+ /*
+ Normalize both x and y, ensure that y >= b/2, [b == 2**MP_DIGIT_BIT]
+ */
+ norm, _ := count_bits(y);
+ norm %= _DIGIT_BITS;
+
+ if norm < _DIGIT_BITS - 1 {
+ norm = (_DIGIT_BITS - 1) - norm;
+ if err = shl(x, x, norm); err != nil { return err; }
+ if err = shl(y, y, norm); err != nil { return err; }
+ } else {
+ norm = 0;
+ }
+
+ /*
+ Note: HAC does 0 based, so if used==5 then it's 0,1,2,3,4, i.e. use 4
+ */
+ n := x.used - 1;
+ t := y.used - 1;
+
+ /*
+ while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} }
+ y = y*b**{n-t}
+ */
+
+ if err = shl_digit(y, n - t); err != nil { return err; }
+
+ c, _ := cmp(x, y);
+ for c != -1 {
+ q.digit[n - t] += 1;
+ if err = sub(x, x, y); err != nil { return err; }
+ c, _ = cmp(x, y);
+ }
+
+ /*
+ Reset y by shifting it back down.
+ */
+ shr_digit(y, n - t);
+
+ /*
+ Step 3. for i from n down to (t + 1).
+ */
+ #no_bounds_check for i := n; i >= (t + 1); i -= 1 {
+ if (i > x.used) { continue; }
+
+ /*
+ step 3.1 if xi == yt then set q{i-t-1} to b-1, otherwise set q{i-t-1} to (xi*b + x{i-1})/yt
+ */
+ if x.digit[i] == y.digit[t] {
+ q.digit[(i - t) - 1] = 1 << (_DIGIT_BITS - 1);
+ } else {
+
+ tmp := _WORD(x.digit[i]) << _DIGIT_BITS;
+ tmp |= _WORD(x.digit[i - 1]);
+ tmp /= _WORD(y.digit[t]);
+ if tmp > _WORD(_MASK) {
+ tmp = _WORD(_MASK);
+ }
+ q.digit[(i - t) - 1] = DIGIT(tmp & _WORD(_MASK));
+ }
+
+ /* while (q{i-t-1} * (yt * b + y{t-1})) >
+ xi * b**2 + xi-1 * b + xi-2
+
+ do q{i-t-1} -= 1;
+ */
+
+ iter := 0;
+
+ q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] + 1) & _MASK;
+ #no_bounds_check for {
+ q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] - 1) & _MASK;
+
+ /*
+ Find left hand.
+ */
+ zero(t1);
+ t1.digit[0] = ((t - 1) < 0) ? 0 : y.digit[t - 1];
+ t1.digit[1] = y.digit[t];
+ t1.used = 2;
+ if err = mul(t1, t1, q.digit[(i - t) - 1]); err != nil { return err; }
+
+ /*
+ Find right hand.
+ */
+ t2.digit[0] = ((i - 2) < 0) ? 0 : x.digit[i - 2];
+ t2.digit[1] = x.digit[i - 1]; /* i >= 1 always holds */
+ t2.digit[2] = x.digit[i];
+ t2.used = 3;
+
+ if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 {
+ break;
+ }
+ iter += 1; if iter > 100 { return .Max_Iterations_Reached; }
+ }
+
+ /*
+ Step 3.3 x = x - q{i-t-1} * y * b**{i-t-1}
+ */
+ if err = int_mul_digit(t1, y, q.digit[(i - t) - 1]); err != nil { return err; }
+ if err = shl_digit(t1, (i - t) - 1); err != nil { return err; }
+ if err = sub(x, x, t1); err != nil { return err; }
+
+ /*
+ if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; }
+ */
+ if x.sign == .Negative {
+ if err = copy(t1, y); err != nil { return err; }
+ if err = shl_digit(t1, (i - t) - 1); err != nil { return err; }
+ if err = add(x, x, t1); err != nil { return err; }
+
+ q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] - 1) & _MASK;
+ }
+ }
+
+ /*
+ Now q is the quotient and x is the remainder, [which we have to normalize]
+ Get sign before writing to c.
+ */
+ z, _ := is_zero(x);
+ x.sign = .Zero_or_Positive if z else numerator.sign;
+
+ if quotient != nil {
+ clamp(q);
+ swap(q, quotient);
+ quotient.sign = .Negative if neg else .Zero_or_Positive;
+ }
+
+ if remainder != nil {
+ if err = shr(x, x, norm); err != nil { return err; }
+ swap(x, remainder);
+ }
+
+ return nil;
+}
+
+/*
+ Slower bit-bang division... also smaller.
+*/
+@(deprecated="Use `_int_div_school`, it's 3.5x faster.")
+_private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
+
+ ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
+ c: int;
+
+ goto_end: for {
+ if err = one(tq); err != nil { break goto_end; }
+
+ num_bits, _ := count_bits(numerator);
+ den_bits, _ := count_bits(denominator);
+ n := num_bits - den_bits;
+
+ if err = abs(ta, numerator); err != nil { break goto_end; }
+ if err = abs(tb, denominator); err != nil { break goto_end; }
+ if err = shl(tb, tb, n); err != nil { break goto_end; }
+ if err = shl(tq, tq, n); err != nil { break goto_end; }
+
+ for n >= 0 {
+ if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
+ // ta -= tb
+ if err = sub(ta, ta, tb); err != nil { break goto_end; }
+ // q += tq
+ if err = add( q, q, tq); err != nil { break goto_end; }
+ }
+ if err = shr1(tb, tb); err != nil { break goto_end; }
+ if err = shr1(tq, tq); err != nil { break goto_end; }
+
+ n -= 1;
+ }
+
+ /*
+ Now q == quotient and ta == remainder.
+ */
+ neg := numerator.sign != denominator.sign;
+ if quotient != nil {
+ swap(quotient, q);
+ z, _ := is_zero(quotient);
+ quotient.sign = .Negative if neg && !z else .Zero_or_Positive;
+ }
+ if remainder != nil {
+ swap(remainder, ta);
+ z, _ := is_zero(numerator);
+ remainder.sign = .Zero_or_Positive if z else numerator.sign;
+ }
+
+ break goto_end;
+ }
+ destroy(ta, tb, tq, q);
+ return err;
+}
+
/*