C 语言

关注公众号 jb51net

关闭
首页 > 软件编程 > C 语言 > C++压位优化

详解C/C++高精度(加减乘除)算法中的压位优化

作者:CodeOfCC

在高精度计算中数组的每个元素存储一位10进制的数字,这样的存储方式并不是最优的,32位的整型其实至少可以存储9位高精度数字,数组元素存储更多的位数就是压位优化。本文将展示压位优化的原理以及压9位的实现和性能对比,需要的可以参考一下

前言

由于上一章《C/C++ 高精度(加减乘除)算法简单实现》实现了基本的高精度计算,数组的每个元素存储一位10进制的数字。这样的存储方式并不是最优的,32位的整型其实至少可以存储9位高精度数字,数组元素存储更多的位数就是压位优化。本文将展示压位优化的原理以及压9位的实现和性能对比。

一、基本原理

1、存储方式

压位优化就是将原本存储一位的数组元素变成存储多位,这样就可以提升运算效率,通常最高能存储9位,如下图示例为存储4位。

2、计算方式

采用模拟立竖式计算,比如加法的计算流程,如下图所示,20481024+80001000=100482024:

二、完整代码

因为接口以及使用方法与上一章《C/C++ 高精度(加减乘除)算法简单实现》是完全一致的,所以这里直接展示完整代码,省略使用示例。下面代码为压9位实现。

#include<stdio.h>
#include<string.h>
#include<math.h>
#include<stdint.h>
/// <summary>
/// 通过字符串初始化
/// </summary>
/// <param name="a">[in]高精度数组</param>
/// <param name="value">[in]字符串首地址</param>
static void loadStr(int* a, const char* value) {
	int len = strlen(value);
	int left = len % 9;
	char s[8], * p = (char*)value + left;
	a[0] = ceil(len / 9.0);
	len = len / 9.0;
	for (int i = 1; i <= len; i++)
		sscanf(p + (len - i ) * 9, "%09d", &a[i]);
	if (left){
		sprintf(s, "%%0%dd", left);
		sscanf(value, s, &a[a[0]]);
	}
}
/// <summary>
/// 输出到字符串,
/// </summary>
/// <param name="a">[in]高精度数组</param>
/// <param name="str">[out]字符串,由外部停供缓冲区,需要保证长度足够</param>
static void toStr(int* a, char* str) {
	if (!a[0]) {
		sprintf(str, "0");
		return;
	}
	sprintf(str, "%d", a[a[0]]);
	str += strlen(str);
	for (int i = a[0]-1; i > 0; i--)
		sprintf(str +( a[0] -i-1)*9, "%09d", a[i]);
	str[(a[0]-1)*9] = '\0';
}
/// <summary>
/// 通过无符号整型初始化
/// </summary>
/// <param name="a">[in]高精度数组</param>
/// <param name="value">[in]整型值</param>
static void loadInt(int* a, uint64_t value) {
	a[0] = 0;
	while (value)a[++a[0]] = value % 1000000000, value /= 1000000000;
}

/// <summary>
/// 比较两个高精度数的大小
/// </summary>
/// <param name="a">[in]第一个数</param>
/// <param name="b">[in]第二个数</param>
/// <returns>1是a>b,0是a==b,-1是a<b</returns>
static int compare(int* a, int* b) {
	if (a[0] > b[0])return 1;
	if (a[0] < b[0])return -1;
	for (int i = a[0]; i > 0; i--)
		if (a[i] > b[i])return 1;
		else if (a[i] < b[i])return -1;
	return 0;
}
/// <summary>
/// 复制
/// </summary>
/// <param name="a">[in]源</param>
/// <param name="b">[in]目标</param>
static void copy(int* a, int* b) {
	memcpy(b, a, (a[0] + 1) * sizeof(int));
}
/// <summary>
/// 打印输出结果
/// </summary>
static void print(int* a) {
	int i = a[0];
	printf("%d", a[i--]);
	for (; i > 0; i--)printf("%09d", a[i]);
}

/// <summary>
/// 加
/// </summary>
/// <param name="a">[in]被加数</param>
/// <param name="b">[in]加数</param>
/// <param name="c">[out]结果</param>
static	void plus(int* a, int* b, int* c) {
	int* p;
	if (a[0] < b[0])p = a, a = b, b = p;//确保a长度最大	
	int i = 1, alen = a[0], blen = b[0];
	c[0] = c[alen + 1] = 0;
	if (a != c)memcpy(c + blen + 1, a + blen + 1, (alen - blen) * sizeof(int));//a多出的部分直接拷贝到结果
	for (; i <= blen; i++) {
		c[i] = a[i] + b[i];
		if (c[i - 1] >= 1000000000)c[i - 1] -= 1000000000, c[i]++;//判断上一位是否进位	
	}
	i--;
	while (c[i++] >= 1000000000)c[i - 1] -= 1000000000, c[i]++;//继续判断进位
	c[0] = c[alen + 1] ? alen + 1 : alen;//记录长度
}
/// <summary>
/// 加等于
///结果会保存在a中
/// </summary>
/// <param name="a">[in]被加数</param>
/// <param name="b">[in]加数</param>
static	void plusq(int* a, int* b) {
	plus(a, b, a);
}

/// <summary>
/// 减
/// </summary>
/// <param name="a">[in]被减数,被减数必须大于等于减数</param>
/// <param name="b">[in]减数</param>
/// <param name="c">[out]结果</param>
static	void sub(int* a, int* b, int* c) {
	int i = 1, alen = a[0];
	if (a != c)memcpy(c + b[0] + 1, a + b[0] + 1, (a[0] - b[0]) * sizeof(int));//a多出的部分直接拷贝到结果
	c[0] = 1;
	for (; i <= b[0]; i++) {
		c[i] = a[i] - b[i];
		if (c[i - 1] < 0)c[i - 1] += 1000000000, c[i] --;//判断上一位是否补位		
	}
	i--;
	while (c[i++] < 0)c[i - 1] += 1000000000, c[i]--;//继续判断补位	
	while (!c[alen--]); c[0] = alen + 1;//记录长度
}
/// <summary>
/// 减法等于
///结果会保存在a中
/// </summary>
/// <param name="a">[in]被减数,被减数必须大于等于减数</param>
/// <param name="b">[in]减数</param>
static	void subq(int* a, int* b) {
	sub(a, b, a);
}

/// <summary>
/// 乘
/// </summary>
/// <param name="a">[in]被乘数</param>
/// <param name="b">[in]乘数</param>
/// <param name="c">[out]结果,数组长度必须大于等于aLen+bLen+1</param>
static	void mul(int* a, int* b, int c[]) {
	int len = a[0] + b[0], d = 0;
	memset(c, 0, sizeof(int) * (len + 1));
	b[b[0] + 1] = 0; c[0] = 1;//防止越界
	for (int i = 1; i <= a[0]; i++)
		for (int j = 1; j <= b[0] + 1; j++){
			int64_t t = (int64_t)a[i] * b[j] + c[j + i - 1] + d;
			c[j + i - 1] = t % 1000000000;
			d = t / 1000000000;
		}
	while (!c[len])len--; c[0] = len;
}
/// <summary>
/// 乘等于
/// 累乘,结果存放于a
/// </summary>
/// <param name="a">[in]被乘数,数组长度必须大于等于2aLen+bLen+1</param>
/// <param name="b">[in]乘数</param>
static	void mulq(int* a, int* b) {
	int* c = a + a[0] + b[0] + 1;
	memcpy(c, a, (a[0] + 1) * sizeof(int));
	mul(c, b, a);
}

/// <summary>
/// 除法
/// 依赖减法subq
/// </summary>
/// <param name="a">[in]被除数,被除数必须大于除数</param>
/// <param name="b">[in]除数</param>
/// <param name="c">[out]商,数组长度大于等于3aLen-bLen+1</param>
/// <param name="mod">[out]余数,可以为NULL,数组长度大于等于aLen</param>>
static void div(int* a, int* b, int* c, int* mod) {
	int len = a[0] - b[0] + 1, times, hTimes[32], * temp = c + a[0] + 1;
	if (!mod)mod = temp + 2*(a[0] + 1)+1;//缓冲区
	memcpy(mod, a, (a[0] + 1) * sizeof(int));
	memset(c, 0, sizeof(int) * (len + 1));
	memset(temp, 0, sizeof(int) * len);
	c[0] = 1;//防止while越界
	for (int i = len; i > 0; i--) {
		memcpy(temp + i, b + 1, sizeof(int) * b[0]);//升阶	
		temp[0] = b[0] + i - 1;
		while (compare(mod, temp) != -1) {
			if (times = (mod[mod[0]] * ((mod[0] - temp[0]) ? 1000000000ll : 1)) / (temp[temp[0]] + (temp[0] == 1 ? 0 : 1)))//升倍数
			{
				loadInt(hTimes,times);
				mulq(temp, hTimes);
			}
			else times = 1;
			while (compare(mod, temp) != -1)subq(mod, temp), c[i] += times;	//减法
			memcpy(temp + i, b + 1, sizeof(int) * b[0]);//还原	
			temp[0] = b[0] + i - 1;
		}
	}
	while (!c[len])len--; c[0] = len;
}
/// <summary>
/// 除等于
/// 商保存在a
/// 依赖div
/// </summary>
/// <param name="a">[in]被除数,被除数必须大于除数</param>
/// <param name="b">[in]除数</param>
/// <param name="mod">[out]余数,可以为NULL,数组长度大于等于aLen</param>>
static void divq(int* a, int* b, int* mod) {
	div(a, b, a, mod);
}

三、性能对比

测试平台:Windows 11

测试设备:i7 8750h

测试方式:测试5次取均值

表1、测试用例

测试用例描述
1整型范围数字计算500000次
2长数字与整型范围数字计算500000次
3长数字与长数字计算500000次

基于上述用例编写程序进行测试,测试结果如下表

表2、测试结果

计算测试用例1位实现(上一章)耗时9位优化(本章)耗时
加法测试用例10.003926s0.002620s
加法测试用例20.026735s0.005711s
加法测试用例30.029378s0.005384s
累加测试用例10.003255s0.002536s
累加测试用例20.017843s0.002592s
累加测试用例30.034025s0.006474s
减法测试用例10.004237s0.002078s
减法测试用例20.024775s0.004939s
减法测试用例30.027634s0.004929s
累减测试用例10.004272s0.002034s
累减测试用例20.0054070.001942s
累减测试用例30.019363s0.004282s
乘法测试用例10.043608s0.004751s
乘法测试用例20.479071s0.028358s
乘法测试用例33.375447s0.064259s
累乘测试用例1 只计算1000次0.001237s0.000137s
累乘测试用例2 只计算1000次0.001577s0.000187s
累乘测试用例3 只计算1000次5.792887s0.081988s
除法测试用例10.025391s0.024763s
除法测试用例25.292809s0.516090s
除法测试用例30.395773s0.073812s
累除测试用例1 只计算1000次0.059054s0.035722s
累除测试用例2 只计算1000次0.103727s0.060936s
累除测试用例3 只计算1000次89.748837s25.126072s

将上表数据进行分类相同类型取均值计算出提升速度如下图所示,仅作参考。

图1、速度提升

总结

以上就是今天要讲的内容,压位优化性能提升是比较显著的,而且实现也很容易,大部分逻辑是一致的只是底数变大了而已。从性能测试结果来看所有计算至少由4倍的提升,乘法性能提升较大有可能是测试方法不严重,这个待以后验证。总的来说,对高精度运算进行压位优化还是很有必要的,尤其是对时间和空间有要求的场景还是比较适用的。

附录 1、性能测试代码

#include<Windows.h>
#include <iostream>
static int a[819200];
static int b[819200];
static int c[819200];
static int mod[819200];
static char str[81920];
/// <summary>
/// 返回当前时间
/// </summary>
/// <returns>当前时间,单位秒,精度微秒</returns>
static double  getCurrentTime()
{
	LARGE_INTEGER ticks, Frequency;
	QueryPerformanceFrequency(&Frequency);
	QueryPerformanceCounter(&ticks);
	return  (double)ticks.QuadPart / (double)Frequency.QuadPart;
}
/// <summary>
/// 性能测试
/// </summary>
static void test() {
	double d = getCurrentTime();
	loadStr(a, "50000");
	loadInt(b, 50000);
	for (int64_t i = 1; i <= 500000; i++) {
		plus(a, b, c);
	}
	printf("plus  performence   1: %llfs\n", getCurrentTime() - d);
	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadInt(b, 5);
	for (int64_t i = 1; i <= 500000; i++) {
		plus(a, b, c);
	}
	printf("plus  performence   2: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500000; i++) {
		plus(b, a, c);
	}
	printf("plus  performence   3: %llfs\n", getCurrentTime() - d);


	d = getCurrentTime();
	loadStr(a, "50000");
	loadInt(b, 50000);
	for (int64_t i = 1; i <= 500000; i++) {
		plusq(a, b);
	}
	printf("plusq performence   1: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	for (int64_t i = 500000000; i <= 500000000 + 500000; i++) {
		loadInt(b, i);
		plusq(a, b);
	}
	printf("plusq performence   2: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(b, "999999999999999999999999999999999999999999999999999999999999999999");
	for (int64_t i = 500000000; i <= 500000000 + 500000; i++) {
		loadInt(a, i);
		plusq(a, b);

	}
	printf("plusq performence   3: %llfs\n", getCurrentTime() - d);


	d = getCurrentTime();
	loadStr(a, "50000");
	loadInt(b, 10000);
	for (int64_t i = 1; i <= 500000; i++) {
		sub(a, b, c);
	}
	printf("sub   performence   1: %llfs\n", getCurrentTime() - d);



	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadInt(b, 11111);
	for (int64_t i = 1; i <= 500000; i++) {
		sub(a, b, c);
	}
	printf("sub   performence   2: %llfs\n", getCurrentTime() - d);


	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500000; i++) {
		sub(a, b, c);
	}
	printf("sub   performence   3: %llfs\n", getCurrentTime() - d);



	d = getCurrentTime();
	loadStr(a, "50000000000");
	loadInt(b, 500000);
	for (int64_t i = 1; i <= 500000; i++) {
		subq(a, b);
	}
	printf("subq  performence   1: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadInt(b, 11111);
	for (int64_t i = 1; i <= 500000; i++) {
		subq(a, b);
	}
	printf("subq  performence   2: %llfs\n", getCurrentTime() - d);


	
	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500000; i++) {
		subq(a, b);
	}
	printf("subq  performence   3: %llfs\n", getCurrentTime() - d);


	d = getCurrentTime();
	loadStr(a, "50000");
	loadInt(b, 12345);
	for (int64_t i = 1; i <= 500000; i++) {
		mul(a, b, c);
	}
	printf("mul   performence   1: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadInt(b, 12345);
	for (int64_t i = 1; i <= 500000; i++) {
		mul(a, b, c);
	}
	printf("mul   performence   2: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500000; i++) {
		mul(b, a, c);
	}
	printf("mul   performence   3: %llfs\n", getCurrentTime() - d);


	d = getCurrentTime();
	loadStr(a, "2");
	loadInt(b, 2);
	for (int64_t i = 1; i <= 1000; i++) {
		mulq(a, b);
	}
	printf("mulq  performence   1: %llfs\n", getCurrentTime() - d);
	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadInt(b, 2);
	for (int64_t i = 1; i <= 1000; i++) {
		mulq(a, b);
	}
	printf("mulq  performence   2: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 1000; i++) {
		mulq(b, a);
	}
	printf("mulq  performence   3: %llfs\n", getCurrentTime() - d);



	d = getCurrentTime();
	loadStr(a, "50000");
	loadInt(b, 12345);
	for (int64_t i = 1; i <= 500000; i++) {
		div(a, b, c, mod);
	}
	printf("div   performence   1: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadInt(b, 12345);
	for (int64_t i = 1; i <= 500000; i++) {
		div(a, b, c, NULL);
	}
	printf("div   performence   2: %llfs\n", getCurrentTime() - d);

	d = getCurrentTime();
	loadStr(a, "100000000000000000000000000000000000000000000000000000000000000000");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500000; i++) {
		div(a, b, c, mod);
	}
	printf("div   performence   3: %llfs\n", getCurrentTime() - d);



	loadStr(a, "1");
	loadStr(b, "2");
	for (int64_t i = 1; i <= 1000; i++) {
		mulq(a, b);
	}
	d = getCurrentTime();
	for (int64_t i = 1; i <= 1000; i++) {
		divq(a, b, mod);
	}
	printf("divq  performence   1: %llfs\n", getCurrentTime() - d);

	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadStr(b, "2");
	for (int64_t i = 1; i <= 1000; i++) {
		mulq(a, b);
	}
	d = getCurrentTime();
	for (int64_t i = 1; i <= 1000; i++) {
		divq(a, b, mod);
	}
	printf("divq  performence   2: %llfs\n", getCurrentTime() - d);



	loadStr(a, "999999999999999999999999999999999999999999999999999999999999999999");
	loadStr(b, "11111111111111111111111111111111111111");
	for (int64_t i = 1; i <= 500; i++) {
		mulq(a, b);
	}
	d = getCurrentTime();
	for (int64_t i = 1; i <= 500; i++) {
		divq(a, b, mod);
	}
	printf("divq  performence   3: %llfs\n", getCurrentTime() - d);
}

到此这篇关于详解C/C++高精度(加减乘除)算法中的压位优化的文章就介绍到这了,更多相关C++压位优化内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文