5

mock C++ function for unit test

 2 years ago
source link: https://byronhe.com/post/2014/03/27/mock-c-plus-plus-function-for-unit-test/
Visit the source link to read the original article, which has images, later updates, and better formatting. If the link is broken, click the button below to view the archived snapshot.
neoserver,ios ssh client

mock C++ function for unit test

2014-03-27

在单元测试中，我们需要为业务逻辑提供 mock 版本。当业务逻辑实现为 C++ 的 virtual function 时，这很容易：只需写一个子类，实现（override）这些 virtual function 即可。Google 的 gMock 正是针对这种情况设计的。

可是，如果遗留代码中有普通 C 函数、非 virtual 的类成员函数、模板函数或 inline 函数，又该如何提供 mock 版本呢？

下面的代码用了一点小技巧（trick），在运行时对上述函数进行 mock。

原理是：在运行时修改目标函数开头的机器码，使其直接 jmp 到 mock 版本的函数中。

实现如下:

#include <errno.h>
#include <stdint.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

#include <iostream>
#include <string>

#include "patch_elf.h"
using namespace std;


/**
 * Debug helper: dump `leng` bytes starting at `addr` to std::cout, each byte
 * printed as "0x<hex> " (no zero padding), preceded by a header line showing
 * the address.
 *
 * std::hex is a sticky stream flag, so the caller's formatting flags are saved
 * and restored; otherwise every later integer written to std::cout would also
 * come out in hexadecimal.
 *
 * @param addr start of the memory to dump; must be readable for `leng` bytes
 * @param leng number of bytes to dump; values <= 0 print no bytes
 * @return always 0
 */
int print_op(void * addr,int leng){
    const unsigned char * op = static_cast<const unsigned char *>(addr);
    const std::ios_base::fmtflags saved_flags = std::cout.flags();

    std::cout << std::endl << "addr:" << addr << " code:" << std::endl;
    std::cout << std::hex;
    for (int i = 0; i < leng; ++i) {
        std::cout << "0x" << static_cast<unsigned int>(op[i]) << " ";
    }
    std::cout.flags(saved_flags);  // undo std::hex for the rest of the program
    std::cout << std::endl;
    return 0;
}

/**
 * Redirect all future calls of `original` to `mock` by overwriting the first
 * bytes of `original`'s machine code with an unconditional jump.
 *
 * Injected code:
 *   x86-64: 49 bb <imm64>   movabs $mock, %r11
 *           41 ff e3        jmpq   *%r11             (13 bytes)
 *     The full 64-bit address is encoded. A 32-bit immediate
 *     (`mov $imm32, %eax`) would truncate `mock` and jump to a wrong address
 *     whenever the mock lives above 4 GiB (PIE executables, shared libraries).
 *     %r11 is used rather than %rax: in the System V x86-64 ABI %al carries
 *     the number of vector registers for variadic calls, while %r11 is a
 *     scratch register that never carries arguments.
 *   i386:   e9 <rel32>      jmp mock                 (5 bytes)
 *     A relative jump reaches any address of a 32-bit address space and
 *     clobbers no register.
 *
 * Caveats: the patch is permanent. The body of `original` must be at least as
 * long as the injected code, otherwise the following bytes (e.g. the next
 * function) are overwritten. Call sites where `original` was inlined are not
 * affected. Patching while another thread runs `original` is unsafe.
 *
 * @param original entry address of the function to redirect
 * @param mock     entry address of the replacement; it must have the same
 *                 signature / calling convention as `original`
 * @return 0 on success; -1 if an argument is NULL, the architecture is not
 *         supported, or the code pages cannot be made writable. A failure to
 *         drop the write permission afterwards is only reported on std::cerr,
 *         because the redirect is already in place at that point.
 */
int patch_func(void  * original,void * mock){
    if (original == NULL || mock == NULL) {
        std::cerr << "patch_func: NULL function address" << std::endl;
        return -1;
    }

    // Build the jump for the current architecture.
    unsigned char inject_code[16];
    size_t code_len = 0;
#if defined(__x86_64__)
    const uint64_t target = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(mock));
    inject_code[0] = 0x49;                              // REX.W + REX.B
    inject_code[1] = 0xbb;                              // mov r11, imm64
    memcpy(&inject_code[2], &target, sizeof target);    // little-endian imm64
    inject_code[10] = 0x41;                             // REX.B
    inject_code[11] = 0xff;                             // jmp r/m64 (/4)
    inject_code[12] = 0xe3;                             // ModRM: register r11
    code_len = 13;
#elif defined(__i386__)
    // rel32 is relative to the end of the 5-byte jmp instruction.
    const uint32_t rel = static_cast<uint32_t>(
        reinterpret_cast<uintptr_t>(mock) - (reinterpret_cast<uintptr_t>(original) + 5));
    inject_code[0] = 0xe9;                              // jmp rel32
    memcpy(&inject_code[1], &rel, sizeof rel);
    code_len = 5;
#endif
    if (code_len == 0) {
        std::cerr << "patch_func: unsupported architecture" << std::endl;
        return -1;
    }

    // mprotect() works on whole pages: cover every page touched by the
    // injected bytes, which may straddle a page boundary.
    const long page_size = sysconf(_SC_PAGESIZE);
    if (page_size <= 0) {
        std::cerr << "patch_func: sysconf(_SC_PAGESIZE) failed" << std::endl;
        return -1;
    }
    const uintptr_t page = static_cast<uintptr_t>(page_size);
    const uintptr_t begin = reinterpret_cast<uintptr_t>(original);
    const uintptr_t first_page = begin / page * page;
    const uintptr_t end_page = (begin + code_len + page - 1) / page * page;
    void * const code_addr = reinterpret_cast<void *>(first_page);
    const size_t length = static_cast<size_t>(end_page - first_page);

    // Add write permission; never write into code we could not unprotect.
    if (mprotect(code_addr, length, PROT_READ | PROT_WRITE | PROT_EXEC) != 0) {
        const int err = errno;
        std::cerr << "mprotect failed! errno=" << err << " (" << strerror(err) << ")" << std::endl;
        return -1;
    }

    // Overwrite the entry of `original` with the jump.
    memcpy(original, inject_code, code_len);
#if defined(__GNUC__)
    // Keep instruction and data caches coherent (a no-op on x86).
    __builtin___clear_cache(static_cast<char *>(original),
                            static_cast<char *>(original) + code_len);
#endif

    // Drop the write permission again.
    if (mprotect(code_addr, length, PROT_READ | PROT_EXEC) != 0) {
        const int err = errno;
        std::cerr << "mprotect failed! errno=" << err << " (" << strerror(err) << ")" << std::endl;
    }
    return 0;
}

测试覆盖了以下几种函数：1. 普通函数；2. inline 函数；3. 普通成员函数；4. 模板函数。

并分别在 32 位、64 位，以及 -O2、-O0 编译参数下进行了编译测试。

结果是：除了 inline 函数无法 mock 之外，其他几种情况都有效。

#pragma once
#include <iostream>
#include <cstdio>
using namespace std;

// Plain record used as the argument type of the functions under test. It mixes
// field kinds: 32- and 64-bit integers, a char buffer, a double and a pointer.
class ST1{
    public:
        uint32_t a;
        uint64_t b;
        char c[200];    // C string; the constructor only sets c[0]
        double d;
        ST1 * e;

        // Zero the numeric fields and the pointer; make `c` an empty string.
        ST1()
            : a(0),
              b(0),
              d(0),
              e(0)
        {
            c[0] = '\0';
        }

        // Member functions: member_func is the patch target in test(),
        // member_func_mock one of its replacements.
        int member_func();
        int member_func_mock();
};

// Free (non-member) function taking an ST1*; test() patches ST1::member_func
// to jump here. NOTE(review): this relies on the ABI passing `this` like an
// ordinary first pointer argument — confirm for the target platform.
int member_func_extern(ST1 * st);


// Ordinary free functions: original_func is the patch target, mock_func its
// replacement with the identical signature.
int original_func(ST1 * para1,ST1 para2,void * para3);
int mock_func(ST1 * para1,ST1 para2,void * para3);
// Declaration left disabled; ref_func itself is defined in the implementation code.
//int ref_func(ST1 * para1,ST1 para2,void * para3);

// Class with a private ST1 member and two 32-bit fields, both zeroed on
// construction. NOTE(review): not referenced by the test program in this post.
class Base{
    private:
        uint32_t b;
        ST1 st;         // default-constructed by ST1's own constructor
    public:
        uint32_t a;

        Base()
            : b(0),
              a(0)
        {
        }
};

//inline 函数
inline int inline_func(int a,int b){
    int c=a+b+ 0x1111 * a + b/0x1111;
    printf("%s %d\n",__func__,c);
    return c;
}

// Replacement for inline_func: returns a + b + 100 and prints
// "inline_func_mock <result>".
inline int inline_func_mock(int a,int b){
    const int result = (a + b) + 100;
    printf("%s %d\n", __func__, result);
    return result;
}

// Template functions

// Prints "get_member_a a=<t.a>" and returns t.a converted to uint32_t.
template <typename T>
uint32_t get_member_a(T & t)
{
    std::cout << __func__ << " a=" << t.a << std::endl;
    const uint32_t value = t.a;     // same implicit conversion a `return` performs
    return value;
}

// Prints "get_member_b b=<t.b>" and returns t.b converted to uint32_t. Wider
// values are truncated (e.g. ST1::b is uint64_t).
template <typename T>
uint32_t get_member_b(T & t)
{
    std::cout << __func__ << " b=" << t.b << std::endl;
    const uint32_t value = t.b;     // same implicit conversion a `return` performs
    return value;
}
#include <stdint.h>
#include <iostream>
#include <string>
#include <unistd.h>
#include "func.h"
using namespace std;

// Real implementation that test() redirects to mock_func. It prints a summary
// of both ST1 arguments and the opaque pointer. Despite the "c+c"/"e+e" labels,
// those fields are printed back to back rather than added, and para3 follows
// para2.e without a separator.
int original_func(ST1 * para1,ST1 para2,void * para3){
    std::cout << __func__ << "\tcalled! ";
    std::cout << " a+a " << para1->a + para2.a;
    std::cout << " b+b " << para1->b + para2.b;
    std::cout << " c+c " << para1->c << para2.c;
    std::cout << " d+d " << para1->d + para2.d;
    std::cout << " e+e " << para1->e << para2.e << para3;
    std::cout << std::endl;
    return 0;
}


// Replacement for original_func: ignores its arguments and only announces
// itself ("mock_func\tcalled!").
int mock_func(ST1 * para1,ST1 para2,void * para3){
    (void)para1;
    (void)para2;
    (void)para3;
    std::cout << __func__ << "\tcalled!" << std::endl;
    return 0;
}

// Forwards all arguments to mock_func and returns its result; the by-value
// ST1 is copied once more for that call.
int ref_func(ST1 * para1,ST1 para2,void * para3){
    const int result = mock_func(para1, para2, para3);
    return result;
}

int ST1::member_func(){
    cout<<__func__<<" called! "
        <<" a="<< this->a
        <<" b="<< this->b
        <<" c="<< this->c
        <<" d="<< this->d
        <<endl;
    return 0;
}

// Member replacement for ST1::member_func: prints a fixed message only.
int ST1::member_func_mock(){
    const char * const message = " called! i do nothing.";
    std::cout << __func__ << message << std::endl;
    return 0;
}

// Free-function replacement for ST1::member_func: ignores `st` and prints a
// fixed message.
int member_func_extern(ST1 * st){
    (void)st;
    const char * const message = " called! i am not member function.";
    std::cout << __func__ << message << std::endl;
    return 0;
}
#include <stdint.h>
#include <iostream>
#include <string>
#include <unistd.h>
#include <string.h>
#include <sys/mman.h>

#include "func.h"
#include "patch_elf.h"

using namespace std;


// Report a failed patch_func() call on stderr, so that the "after patching"
// output of the demo is not mistaken for a successful redirect.
static void report_patch_result(int rc, const char * what){
    if (rc != 0) {
        std::cerr << "patch_func failed for " << what << " (rc=" << rc << ")" << std::endl;
    }
}

/**
 * Demo driver. For each kind of function it calls the function, redirects it
 * with patch_func(), and calls it again, so the console output shows whether
 * the mock took effect.
 *
 * Cases covered:
 *  - a free function;
 *  - an inline function, which appears to be unaffected, presumably because
 *    the calls are inlined;
 *  - a non-virtual member function, redirected first to another member and
 *    then to a free function;
 *  - a template instantiation.
 *
 * NOTE: casting `&ST1::member_func` to void* is a GCC extension (accepted with
 * a -Wpmf-conversions warning); standard C++ does not allow it.
 *
 * @return always 0; patch failures are reported on std::cerr.
 */
int test(){
    ST1 s1,s2;
    char str[]="hello";
    s1.a=s1.b=s1.d=1;
    s1.e=NULL;

    s2.a=s2.b=s2.d=2;
    s2.e=NULL;  // s2's own pointer (this line previously re-assigned s1.e)

    cout<<"----------------------------------------------------------------------"<<endl;

    // Free function: redirect original_func to mock_func.
    original_func(&s1,s2,&str[0]);

    report_patch_result(patch_func((void*)&original_func, (void*)&mock_func), "original_func");

    original_func(&s1,s2,&str[0]);

    cout<<"----------------------------------------------------------------------"<<endl;

    // Inline function: mocking it does not seem to work (calls are inlined).
    int a=s1.a, b=s1.b;
    inline_func(a,b);
    report_patch_result(patch_func( (void*) &inline_func, (void*) &inline_func_mock), "inline_func");
    inline_func(a,b);

    cout<<"----------------------------------------------------------------------"<<endl;

    // Non-virtual member function: redirect to another member, then to a free function.
    s1.member_func();
    report_patch_result(patch_func( (void*) &ST1::member_func, (void*) &ST1::member_func_mock), "ST1::member_func");
    s1.member_func();
    report_patch_result(patch_func( (void*) &ST1::member_func, (void*) &member_func_extern), "ST1::member_func");
    s1.member_func();

    cout<<"----------------------------------------------------------------------"<<endl;

    // Template instantiation: redirect get_member_a<ST1> to get_member_b<ST1>.
    get_member_a(s1);
    report_patch_result(patch_func( (void*) & get_member_a<ST1> , (void*) & get_member_b<ST1>), "get_member_a<ST1>");
    get_member_a(s1);

    return 0;
}

// Program entry point: runs the patching demo and always exits with status 0.
int main(){
    (void)test();
    return 0;
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK