LinuxSir.cn,穿越时空的Linuxsir!

 找回密码
 注册
搜索
热搜: shell linux mysql
查看: 1014|回复: 4

我的程序用C编写通过了,C++就是通不过,高手给指点一下!

[复制链接]
发表于 2003-4-17 17:28:30 | 显示全部楼层 |阅读模式
/线搜索的目的: 寻找步长,使新的迭代点的梯度比出发点小
#include<iostream.h>
#include<math.h>
class A;
typedef double (A::*PF)(double*);
class A
{
        public:
                A(){}
                double line_search(PF,int n,double * start,double *direction);//线搜索,start初始出发点
                                                                              //direction搜索方向   
                ~A(void){}
                void compute_grad(PF,int n,double *point);//计算函数f(X)的梯度
                double fun(double*);
        private:
                double a,value_a,diver_a;
                double b,value_b,diver_b;
                double t,value_t,diver_t;
                double s,w,z;
                double *grad,*temp_point;
};

double A::line_search(PF  pf,int n,double *start,double *direction)
{
        double h=1E-6;
        grad=new double[n];
        temp_point=new double[n];
        for(int i=0;i<n;i++)       
                temp_point=start+a*direction;//start即X0
        compute_grad(this->*pf,n,temp_point);//计算新的迭代点的梯度
        for(int i=0;i<n;i++)
                diver_a+=grad*direction;//f(X0+a*direction)关于a的导数
        do
        {
                b=a+h;
                for(int i=0;i<n;i++)
                        temp_point=start+b*direction;
                compute_grad(this->*pf,n,temp_point);
                for(int i=0;i<n;i++)
                        diver_b+=grad*direction;//关于b的导数
                if(fabs(diver_b)<=1E-10) break;//导数等于0
                if(diver_b<-1E-15)//导数小于0
                {
                        a=b;
                        diver_a=diver_b;
                        h*=2;
                }
        }while(fabs(diver_b)<1E-15);//以上部分是寻找区间[a,b]
        for(int i=0;i<n;i++)//以下从新的[a,b]出发
                temp_point=start+a*direction;
        value_a=(this->*pf)(temp_point);//f(a)
        for(int i=0;i<n;i++)
                temp_point=start+b*direction;
        compute_grad(this->*pf,n,temp_point);
        for(int i=0;i<n;i++)
                diver_b+=grad*direction;//b的导数
        value_b=(this->*pf)(temp_point);//f(b)
        do//两点三次插值法
        {
                s=3*(value_b-value_a)/(b-a);
                z=s-diver_a-diver_b;
                w=sqrt(z*z-diver_a*diver_b);
                t=a+(w-z-diver_a)*(b-a)/(diver_b-diver_a+2*w);
                for(int i=0;i<n;i++)
                        temp_point=start+t*direction;
                value_t=(this->*pf)(temp_point);
                compute_grad(this->*pf,n,temp_point);
                for(int i=0;i<n;i++)
                        diver_t+=grad*direction;//计算出来的t的导数,目的是它等于0则ok
                if(diver_t>1E-6)
                {
                        b=t;
                        diver_b=diver_t;
                        value_b=value_t;
                }
                else if(diver_t<-1E-6)
                {
                        a=t;
                        diver_a=diver_t;
                        value_a=value_t;
                }
                else break;
        }while(fabs(diver_t)>=1E-6&&fabs(b-a)>1E-6);
        return t;
}

void A::compute_grad(PF pf,int n,double *point)//求梯度公式
{
        double h=1E-3;
        double *temp=new double[n];
        for(int i=0;i<n;i++)
                temp=point;
        for(int i=0;i<n;i++)
        {
                temp+=(0.5*h);
                grad+=(4*(this->*pf)(temp)/(3*h));
                temp-=h;
                grad-=(4*(this->*pf)(temp)/(3*h));
                temp+=(3*h/2);
                grad-=((this->*pf)(temp))/(6*h);
                temp-=(2*h);
                grad+=((this->*pf)(temp))/(6*h);
        }
        delete[] temp;
}//g++编译器显示:ERROR:Aborted!

double A::fun(double* x)
{
        return x[0]*x[0]+2*x[1]*x[1]+3*x[2]*x[2]+4*x[3]*x[3];
}

void main()
{
        double x[4]={1,1,1,1};
        double p[4]={-2,-4,-6,-8};
        double landa;
        PF pf=&A::fun;
        A s;
    landa=s.line_search(pf,4,x,p);
        for(int i=0;i<4;i++)
                x+=landa*p;
        for(int i=0;i<4;i++)
                cout<<"x["<<i+1<<"]="<<x<<endl;
        cout<<"f(X)= "<<s.fun(x)<<endl;

}
发表于 2003-4-18 09:06:40 | 显示全部楼层
出什么错误提示?
发表于 2003-4-18 09:12:48 | 显示全部楼层
太长了!
抓住要点说
 楼主| 发表于 2003-4-18 18:06:49 | 显示全部楼层
显示错误:Aborted!
我是用g++编译的
发表于 2003-4-20 14:18:21 | 显示全部楼层
你的这个是计算方法的实验报告吧?
你可以在vc下编译试试
您需要登录后才可以回帖 登录 | 注册

本版积分规则

快速回复 返回顶部 返回列表