diff options
| author | Jeroen van Rijn <Kelimion@users.noreply.github.com> | 2021-08-07 16:52:04 +0200 |
|---|---|---|
| committer | Jeroen van Rijn <Kelimion@users.noreply.github.com> | 2021-08-11 20:59:53 +0200 |
| commit | 62dcccd7ef21ba4c0049f2051bf1cd15433fef8d (patch) | |
| tree | cfca19d9a22a6984d1c3beefcebf753753346b93 /core/math/big/internal.odin | |
| parent | e288a563e1cbe4432b457c6891bf86abfec27412 (diff) | |
big: Move division internals.
Diffstat (limited to 'core/math/big/internal.odin')
| -rw-r--r-- | core/math/big/internal.odin | 357 |
1 files changed, 349 insertions, 8 deletions
diff --git a/core/math/big/internal.odin b/core/math/big/internal.odin index 8505f6f31..a3e2548db 100644 --- a/core/math/big/internal.odin +++ b/core/math/big/internal.odin @@ -608,7 +608,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc /* Fast comba? */ // err = s_mp_sqr_comba(a, c); } else { - err = _int_sqr(dest, src); + err = _private_int_sqr(dest, src); } } else { /* @@ -680,14 +680,13 @@ internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, a // err = _int_div_recursive(quotient, remainder, numerator, denominator); } else { when true { - err = _int_div_school(quotient, remainder, numerator, denominator); + err = _private_int_div_school(quotient, remainder, numerator, denominator); } else { /* - NOTE(Jeroen): We no longer need or use `_int_div_small`. + NOTE(Jeroen): We no longer need or use `_private_int_div_small`. We'll keep it around for a bit until we're reasonably certain div_school is bug free. - err = _int_div_small(quotient, remainder, numerator, denominator); */ - err = _int_div_small(quotient, remainder, numerator, denominator); + err = _private_int_div_small(quotient, remainder, numerator, denominator); } } return; @@ -744,7 +743,7 @@ internal_int_divmod_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT) Three? */ if denominator == 3 { - return _int_div_3(quotient, numerator); + return _private_int_div_3(quotient, numerator); } /* @@ -1049,8 +1048,6 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) { /* Now extract the previous digit [below the carry]. */ - // for ix = 0; ix < pa; ix += 1 { dest.digit[ix] = W[ix]; } - copy_slice(dest.digit[0:], W[:pa]); /* @@ -1065,6 +1062,350 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) { return clamp(dest); } +/* + Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 +*/ +_private_int_sqr :: proc(dest, src: ^Int) -> (err: Error) { + pa := src.used; + + t := &Int{}; ix, iy: int; + /* + Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger. + */ + if err = grow(t, max((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != nil { return err; } + t.used = (2 * pa) + 1; + + #no_bounds_check for ix = 0; ix < pa; ix += 1 { + carry := DIGIT(0); + /* + First calculate the digit at 2*ix; calculate double precision result. + */ + r := _WORD(t.digit[ix+ix]) + (_WORD(src.digit[ix]) * _WORD(src.digit[ix])); + + /* + Store lower part in result. + */ + t.digit[ix+ix] = DIGIT(r & _WORD(_MASK)); + /* + Get the carry. + */ + carry = DIGIT(r >> _DIGIT_BITS); + + #no_bounds_check for iy = ix + 1; iy < pa; iy += 1 { + /* + First calculate the product. + */ + r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]); + + /* Now calculate the double precision result. NĂ³te we use + * addition instead of *2 since it's easier to optimize + */ + r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry); + + /* + Store lower part. + */ + t.digit[ix+iy] = DIGIT(r & _WORD(_MASK)); + + /* + Get carry. + */ + carry = DIGIT(r >> _DIGIT_BITS); + } + /* + Propagate upwards. + */ + #no_bounds_check for carry != 0 { + r = _WORD(t.digit[ix+iy]) + _WORD(carry); + t.digit[ix+iy] = DIGIT(r & _WORD(_MASK)); + carry = DIGIT(r >> _WORD(_DIGIT_BITS)); + iy += 1; + } + } + + err = clamp(t); + swap(dest, t); + destroy(t); + return err; +} + +/* + Divide by three (based on routine from MPI and the GMP manual). +*/ +_private_int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: DIGIT, err: Error) { + /* + b = 2^_DIGIT_BITS / 3 + */ + b := _WORD(1) << _WORD(_DIGIT_BITS) / _WORD(3); + + q := &Int{}; + if err = grow(q, numerator.used); err != nil { return 0, err; } + q.used = numerator.used; + q.sign = numerator.sign; + + w, t: _WORD; + #no_bounds_check for ix := numerator.used; ix >= 0; ix -= 1 { + w = (w << _WORD(_DIGIT_BITS)) | _WORD(numerator.digit[ix]); + if w >= 3 { + /* + Multiply w by [1/3]. + */ + t = (w * b) >> _WORD(_DIGIT_BITS); + + /* + Now subtract 3 * [w/3] from w, to get the remainder. + */ + w -= t+t+t; + + /* + Fixup the remainder as required since the optimization is not exact. + */ + for w >= 3 { + t += 1; + w -= 3; + } + } else { + t = 0; + } + q.digit[ix] = DIGIT(t); + } + remainder = DIGIT(w); + + /* + [optional] store the quotient. + */ + if quotient != nil { + err = clamp(q); + swap(q, quotient); + } + destroy(q); + return remainder, nil; +} + +/* + Signed Integer Division + + c*b + d == a [i.e. a/b, c=quotient, d=remainder], HAC pp.598 Algorithm 14.20 + + Note that the description in HAC is horribly incomplete. + For example, it doesn't consider the case where digits are removed from 'x' in + the inner loop. + + It also doesn't consider the case that y has fewer than three digits, etc. + The overall algorithm is as described as 14.20 from HAC but fixed to treat these cases. +*/ +_private_int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) { + // if err = error_if_immutable(quotient, remainder); err != nil { return err; } + // if err = clear_if_uninitialized(quotient, numerator, denominator); err != nil { return err; } + + q, x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; + defer destroy(q, x, y, t1, t2); + + if err = grow(q, numerator.used + 2); err != nil { return err; } + q.used = numerator.used + 2; + + if err = init_multi(t1, t2); err != nil { return err; } + if err = copy(x, numerator); err != nil { return err; } + if err = copy(y, denominator); err != nil { return err; } + + /* + Fix the sign. + */ + neg := numerator.sign != denominator.sign; + x.sign = .Zero_or_Positive; + y.sign = .Zero_or_Positive; + + /* + Normalize both x and y, ensure that y >= b/2, [b == 2**MP_DIGIT_BIT] + */ + norm, _ := count_bits(y); + norm %= _DIGIT_BITS; + + if norm < _DIGIT_BITS - 1 { + norm = (_DIGIT_BITS - 1) - norm; + if err = shl(x, x, norm); err != nil { return err; } + if err = shl(y, y, norm); err != nil { return err; } + } else { + norm = 0; + } + + /* + Note: HAC does 0 based, so if used==5 then it's 0,1,2,3,4, i.e. use 4 + */ + n := x.used - 1; + t := y.used - 1; + + /* + while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } + y = y*b**{n-t} + */ + + if err = shl_digit(y, n - t); err != nil { return err; } + + c, _ := cmp(x, y); + for c != -1 { + q.digit[n - t] += 1; + if err = sub(x, x, y); err != nil { return err; } + c, _ = cmp(x, y); + } + + /* + Reset y by shifting it back down. + */ + shr_digit(y, n - t); + + /* + Step 3. for i from n down to (t + 1). + */ + #no_bounds_check for i := n; i >= (t + 1); i -= 1 { + if (i > x.used) { continue; } + + /* + step 3.1 if xi == yt then set q{i-t-1} to b-1, otherwise set q{i-t-1} to (xi*b + x{i-1})/yt + */ + if x.digit[i] == y.digit[t] { + q.digit[(i - t) - 1] = 1 << (_DIGIT_BITS - 1); + } else { + + tmp := _WORD(x.digit[i]) << _DIGIT_BITS; + tmp |= _WORD(x.digit[i - 1]); + tmp /= _WORD(y.digit[t]); + if tmp > _WORD(_MASK) { + tmp = _WORD(_MASK); + } + q.digit[(i - t) - 1] = DIGIT(tmp & _WORD(_MASK)); + } + + /* while (q{i-t-1} * (yt * b + y{t-1})) > + xi * b**2 + xi-1 * b + xi-2 + + do q{i-t-1} -= 1; + */ + + iter := 0; + + q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] + 1) & _MASK; + #no_bounds_check for { + q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] - 1) & _MASK; + + /* + Find left hand. + */ + zero(t1); + t1.digit[0] = ((t - 1) < 0) ? 0 : y.digit[t - 1]; + t1.digit[1] = y.digit[t]; + t1.used = 2; + if err = mul(t1, t1, q.digit[(i - t) - 1]); err != nil { return err; } + + /* + Find right hand. + */ + t2.digit[0] = ((i - 2) < 0) ? 0 : x.digit[i - 2]; + t2.digit[1] = x.digit[i - 1]; /* i >= 1 always holds */ + t2.digit[2] = x.digit[i]; + t2.used = 3; + + if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 { + break; + } + iter += 1; if iter > 100 { return .Max_Iterations_Reached; } + } + + /* + Step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} + */ + if err = int_mul_digit(t1, y, q.digit[(i - t) - 1]); err != nil { return err; } + if err = shl_digit(t1, (i - t) - 1); err != nil { return err; } + if err = sub(x, x, t1); err != nil { return err; } + + /* + if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } + */ + if x.sign == .Negative { + if err = copy(t1, y); err != nil { return err; } + if err = shl_digit(t1, (i - t) - 1); err != nil { return err; } + if err = add(x, x, t1); err != nil { return err; } + + q.digit[(i - t) - 1] = (q.digit[(i - t) - 1] - 1) & _MASK; + } + } + + /* + Now q is the quotient and x is the remainder, [which we have to normalize] + Get sign before writing to c. + */ + z, _ := is_zero(x); + x.sign = .Zero_or_Positive if z else numerator.sign; + + if quotient != nil { + clamp(q); + swap(q, quotient); + quotient.sign = .Negative if neg else .Zero_or_Positive; + } + + if remainder != nil { + if err = shr(x, x, norm); err != nil { return err; } + swap(x, remainder); + } + + return nil; +} + +/* + Slower bit-bang division... also smaller. +*/ +@(deprecated="Use `_int_div_school`, it's 3.5x faster.") +_private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) { + + ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{}; + c: int; + + goto_end: for { + if err = one(tq); err != nil { break goto_end; } + + num_bits, _ := count_bits(numerator); + den_bits, _ := count_bits(denominator); + n := num_bits - den_bits; + + if err = abs(ta, numerator); err != nil { break goto_end; } + if err = abs(tb, denominator); err != nil { break goto_end; } + if err = shl(tb, tb, n); err != nil { break goto_end; } + if err = shl(tq, tq, n); err != nil { break goto_end; } + + for n >= 0 { + if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 { + // ta -= tb + if err = sub(ta, ta, tb); err != nil { break goto_end; } + // q += tq + if err = add( q, q, tq); err != nil { break goto_end; } + } + if err = shr1(tb, tb); err != nil { break goto_end; } + if err = shr1(tq, tq); err != nil { break goto_end; } + + n -= 1; + } + + /* + Now q == quotient and ta == remainder. + */ + neg := numerator.sign != denominator.sign; + if quotient != nil { + swap(quotient, q); + z, _ := is_zero(quotient); + quotient.sign = .Negative if neg && !z else .Zero_or_Positive; + } + if remainder != nil { + swap(remainder, ta); + z, _ := is_zero(numerator); + remainder.sign = .Zero_or_Positive if z else numerator.sign; + } + + break goto_end; + } + destroy(ta, tb, tq, q); + return err; +} + /* |