feat: change q-learning method to fit 'on 4 in a row'
This commit is contained in:
parent
b024eec8e4
commit
0fb79b3e1e
@ -6,11 +6,12 @@
|
|||||||
#define LAMBDA 0.9 // discount factor
|
#define LAMBDA 0.9 // discount factor
|
||||||
|
|
||||||
#define STATE_NUM 19683
|
#define STATE_NUM 19683
|
||||||
#define ACTION_NUM 9
|
#define ACTION_NUM 7
|
||||||
#define EPISODE_NUM 100000
|
#define EPISODE_NUM 1000000
|
||||||
#define FIRST true
|
#define FIRST true
|
||||||
|
|
||||||
#define ROW_NUM 6
|
#define ROW_NUM 6
|
||||||
#define COL_NUM 7
|
#define COL_NUM 7
|
||||||
|
|
||||||
#define BIGNUM_LEN 22
|
#define BIGNUM_LEN 22
|
||||||
|
#define TABLE_SIZE 1000000000
|
||||||
|
|||||||
90
hash-table.c
90
hash-table.c
@ -3,14 +3,8 @@
|
|||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
|
#include "hash-table.h"
|
||||||
#define TABLE_SIZE 10
|
#include "constant.h"
|
||||||
|
|
||||||
struct Node {
|
|
||||||
char key[48];
|
|
||||||
int value;
|
|
||||||
struct Node *next;
|
|
||||||
};
|
|
||||||
|
|
||||||
long long hash_function(char *key) {
|
long long hash_function(char *key) {
|
||||||
long long hash = 0;
|
long long hash = 0;
|
||||||
@ -20,115 +14,69 @@ long long hash_function(char *key) {
|
|||||||
return hash ;
|
return hash ;
|
||||||
}
|
}
|
||||||
|
|
||||||
void insert(struct Node **table, char *key, int value) {
|
void insert(struct Node **table, char *key) {
|
||||||
long long hash = hash_function(key);
|
long long hash = hash_function(key);
|
||||||
|
|
||||||
printf("Hash: %lli\n", hash);
|
|
||||||
|
|
||||||
struct Node *node = malloc(sizeof(struct Node));
|
struct Node *node = malloc(sizeof(struct Node));
|
||||||
struct Node *temp, *past;
|
struct Node *temp, *past;
|
||||||
strcpy(node->key, key);
|
strcpy(node->key, key);
|
||||||
node->value = value;
|
// init
|
||||||
|
for (short i=0; i<ACTION_NUM; i++){
|
||||||
|
node->value[i] = 0.0;
|
||||||
|
}
|
||||||
node->next = NULL;
|
node->next = NULL;
|
||||||
|
|
||||||
if (table[hash] == NULL){
|
if (table[hash] == NULL){
|
||||||
table[hash] = node;
|
table[hash] = node;
|
||||||
printf("Create.\n");
|
|
||||||
} else {
|
} else {
|
||||||
printf("Add.\n");
|
|
||||||
temp = table[hash];
|
temp = table[hash];
|
||||||
past = NULL;
|
past = NULL;
|
||||||
while(temp != NULL){
|
while(temp != NULL){
|
||||||
assert(temp->key != key);
|
assert(strcmp(temp->key, key)!=0);
|
||||||
printf("%s -> ", temp->key);
|
|
||||||
past = temp;
|
past = temp;
|
||||||
temp = temp->next;
|
temp = temp->next;
|
||||||
}
|
}
|
||||||
printf("\n");
|
|
||||||
past->next = node;
|
past->next = node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void long_to_str(long long num, char *s, int length) {
|
void search(struct Node **table, char *key, bool *find, float *ans) {
|
||||||
int temp;
|
|
||||||
for (int i=length-1; i>=0; i--) {
|
|
||||||
temp = num % 10;
|
|
||||||
num /= 10;
|
|
||||||
s[i] = (char)(temp + 48);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int search(struct Node **table, char *key) {
|
|
||||||
long long hash = hash_function(key);
|
long long hash = hash_function(key);
|
||||||
struct Node *temp, *past;
|
struct Node *temp, *past;
|
||||||
|
*find = false;
|
||||||
|
|
||||||
if (table[hash] == NULL){
|
if (table[hash] != NULL){
|
||||||
return -1;
|
|
||||||
} else {
|
|
||||||
temp = table[hash];
|
temp = table[hash];
|
||||||
past = NULL;
|
past = NULL;
|
||||||
|
|
||||||
while(temp != NULL){
|
while(temp != NULL){
|
||||||
// printf("%s - %s\n", temp->key, key);
|
|
||||||
if (strcmp(temp->key, key) == 0){
|
if (strcmp(temp->key, key) == 0){
|
||||||
return temp->value;
|
*find = true;
|
||||||
|
for (short i=0; i<ACTION_NUM; i++){
|
||||||
|
ans[i] = temp->value[i];
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
past = temp;
|
past = temp;
|
||||||
temp = temp->next;
|
temp = temp->next;
|
||||||
}
|
}
|
||||||
return -1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void update(struct Node **table, char *key, int value) {
|
void update(struct Node **table, char *key, short action, float value) {
|
||||||
long long hash = hash_function(key);
|
long long hash = hash_function(key);
|
||||||
struct Node *temp, *past;
|
struct Node *temp, *past;
|
||||||
|
assert(table[hash]!=NULL);
|
||||||
|
|
||||||
temp = table[hash];
|
temp = table[hash];
|
||||||
past = NULL;
|
past = NULL;
|
||||||
while(temp != NULL){
|
while(temp != NULL){
|
||||||
if (strcmp(temp->key, key) == 0){
|
if (strcmp(temp->key, key) == 0){
|
||||||
temp->value = value;
|
temp->value[action] = value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
past = temp;
|
past = temp;
|
||||||
temp = temp->next;
|
temp = temp->next;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(){
|
|
||||||
struct Node ** table; // pointer to pointer
|
|
||||||
int size;
|
|
||||||
srand(time(NULL));
|
|
||||||
|
|
||||||
table = malloc(TABLE_SIZE * sizeof(struct Node*));
|
|
||||||
for (int i=0; i<TABLE_SIZE; i++){
|
|
||||||
table[i] = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
long long a = 1234567890;
|
|
||||||
char s[21];
|
|
||||||
for (int i=0; i<50; i++){
|
|
||||||
a = (long long)rand();
|
|
||||||
printf("%lli\n", a);
|
|
||||||
long_to_str(a, s, 20);
|
|
||||||
printf("%s\n", s);
|
|
||||||
insert(table, s, i);
|
|
||||||
printf("\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
int ans;
|
|
||||||
while (1) {
|
|
||||||
printf("> ");
|
|
||||||
scanf("%lli", &a);
|
|
||||||
printf("HERE\n");
|
|
||||||
long_to_str(a, s, 20);
|
|
||||||
printf("HERE\n");
|
|
||||||
|
|
||||||
update(table, s, 100);
|
|
||||||
ans = search(table, s);
|
|
||||||
printf("%d\n\n", ans);
|
|
||||||
}
|
|
||||||
// long long a = hash_function("9999999999999");
|
|
||||||
// printf("%lli\n", a);
|
|
||||||
}
|
|
||||||
13
hash-table.h
Normal file
13
hash-table.h
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#include "constant.h"
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
struct Node {
|
||||||
|
char key[BIGNUM_LEN+1];
|
||||||
|
float value[ACTION_NUM];
|
||||||
|
struct Node *next;
|
||||||
|
};
|
||||||
|
|
||||||
|
long long hash_function(char *key);
|
||||||
|
void insert(struct Node **table, char *key);
|
||||||
|
void search(struct Node **table, char *key, bool *find, float *ans);
|
||||||
|
void update(struct Node **table, char *key, short action, float value);
|
||||||
22
main.c
22
main.c
@ -7,13 +7,21 @@
|
|||||||
#include "q-learning.h"
|
#include "q-learning.h"
|
||||||
|
|
||||||
int main(){
|
int main(){
|
||||||
short board[9]= {0}; // tic tac toe's chessboard
|
short board[ROW_NUM][COL_NUM]= {0};
|
||||||
float table[STATE_NUM][ACTION_NUM]; // q-learning table
|
short winner;
|
||||||
|
struct Node ** map; // pointer to pointer, hash table
|
||||||
|
bool find;
|
||||||
|
float state[ACTION_NUM];
|
||||||
|
|
||||||
srand(time(NULL));
|
srand(time(NULL));
|
||||||
init_table(&table[0][0]);
|
|
||||||
|
|
||||||
run(&table[0][0], board, false, 10000, false);
|
// init hash table
|
||||||
run(&table[0][0], board, true, EPISODE_NUM, false);
|
map = malloc(TABLE_SIZE * sizeof(struct Node*));
|
||||||
run(&table[0][0], board, false, 10000, false);
|
for (int i=0; i<TABLE_SIZE; i++){
|
||||||
|
map[i] = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
run(map, &board[0][0], false, 10000, false);
|
||||||
|
run(map, &board[0][0], true, EPISODE_NUM, false);
|
||||||
|
run(map, &board[0][0], false, 10000, true);
|
||||||
}
|
}
|
||||||
|
|||||||
146
q-learning.c
146
q-learning.c
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include "constant.h"
|
#include "constant.h"
|
||||||
#include "enviroment.h"
|
#include "enviroment.h"
|
||||||
|
#include "hash-table.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Return the index with the max value in the array
|
Return the index with the max value in the array
|
||||||
@ -37,36 +38,49 @@ short float_argmax(float *arr, short length){
|
|||||||
Args:
|
Args:
|
||||||
- short *table (array's address): state table for Q-Learning
|
- short *table (array's address): state table for Q-Learning
|
||||||
- short *board (array's address): chessboards' status
|
- short *board (array's address): chessboards' status
|
||||||
- int state (integer, state hash): hash for board's status
|
- char *state (string, state hash): hash for board's status
|
||||||
|
|
||||||
Results:
|
Results:
|
||||||
- short best_choice
|
- short best_choice
|
||||||
*/
|
*/
|
||||||
short bot_choose_action(float *table, short *board, int state){
|
short bot_choose_action(struct Node **map, short *board, char *state){
|
||||||
|
|
||||||
// get available actions for choosing
|
// get available actions for choosing
|
||||||
short available_actions[9];
|
short available_actions[ACTION_NUM];
|
||||||
short available_actions_length;
|
short available_actions_length;
|
||||||
get_available_actions(board, available_actions, &available_actions_length);
|
get_available_actions(board, available_actions, &available_actions_length);
|
||||||
|
|
||||||
// use argmax() to find the best choise,
|
// use argmax() to find the best choise,
|
||||||
// first we should build an available_actions_state array for saving the state for all available choise.
|
// first we should build an available_actions_state array for saving the state for all available choise.
|
||||||
float available_actions_state[9];
|
float available_actions_state[ACTION_NUM];
|
||||||
short available_actions_state_index[9];
|
short available_actions_state_index[ACTION_NUM];
|
||||||
short available_actions_state_length, index = 0;
|
short available_actions_state_length, index = 0;
|
||||||
short temp_index, best_choice;
|
short temp_index, best_choice;
|
||||||
bool zeros = true;
|
bool zeros = true;
|
||||||
for (short i=0; i<available_actions_length; i++){
|
bool find;
|
||||||
temp_index = available_actions[i];
|
float state_weights[ACTION_NUM];
|
||||||
available_actions_state[index] = *(table + state * ACTION_NUM + temp_index);
|
|
||||||
if (available_actions_state[index] != 0.0){
|
// find weights in the hash table
|
||||||
zeros = false;
|
search(map, state, &find, state_weights);
|
||||||
}
|
if (!find) {
|
||||||
available_actions_state_index[index] = temp_index;
|
for (short i=0; i<ACTION_NUM; i++){
|
||||||
index++;
|
state_weights[i] = 0.0;
|
||||||
}
|
}
|
||||||
best_choice = float_argmax(available_actions_state, index);
|
}
|
||||||
best_choice = available_actions_state_index[best_choice];
|
|
||||||
|
// get the best choice
|
||||||
|
for (short i=0; i<available_actions_length; i++){
|
||||||
|
temp_index = available_actions[i];
|
||||||
|
|
||||||
|
available_actions_state[index] = state_weights[temp_index];
|
||||||
|
if (available_actions_state[index] != 0.0){
|
||||||
|
zeros = false;
|
||||||
|
}
|
||||||
|
available_actions_state_index[index] = temp_index;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
best_choice = float_argmax(available_actions_state, index);
|
||||||
|
best_choice = available_actions_state_index[best_choice];
|
||||||
|
|
||||||
// Epsilon-Greedy
|
// Epsilon-Greedy
|
||||||
// If random number > EPSILON -> random a action
|
// If random number > EPSILON -> random a action
|
||||||
@ -83,17 +97,15 @@ short bot_choose_action(float *table, short *board, int state){
|
|||||||
Opponent random choose a action to do.
|
Opponent random choose a action to do.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- short *table (array's address): state table for Q-Learning
|
|
||||||
- short *board (array's address): chessboards' status
|
- short *board (array's address): chessboards' status
|
||||||
- int state (integer, state hash): hash for board's status
|
|
||||||
|
|
||||||
Results:
|
Results:
|
||||||
- short choice (integer): random, -1 means no available action to choose
|
- short choice (integer): random, -1 means no available action to choose
|
||||||
*/
|
*/
|
||||||
short opponent_random_action(float *table, short *board, int state){
|
short opponent_random_action(short *board){
|
||||||
|
|
||||||
// get available actions for choosing
|
// get available actions for choosing
|
||||||
short available_actions[9];
|
short available_actions[ACTION_NUM];
|
||||||
short available_action_length;
|
short available_action_length;
|
||||||
get_available_actions(board, available_actions, &available_action_length);
|
get_available_actions(board, available_actions, &available_action_length);
|
||||||
|
|
||||||
@ -109,22 +121,24 @@ short opponent_random_action(float *table, short *board, int state){
|
|||||||
return choice;
|
return choice;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
// Use Hash Table, so we needn't initilize Q-Table
|
||||||
Inilialize the Q-Table
|
//
|
||||||
|
// /*
|
||||||
|
// Inilialize the Q-Table
|
||||||
|
|
||||||
Args:
|
// Args:
|
||||||
- float *table (two-dim array's start address)
|
// - float *table (two-dim array's start address)
|
||||||
|
|
||||||
Results:
|
// Results:
|
||||||
- None.
|
// - None.
|
||||||
*/
|
// */
|
||||||
void init_table(float *table){
|
// void init_table(float *table){
|
||||||
for (int i=0; i<STATE_NUM; i++){
|
// for (int i=0; i<STATE_NUM; i++){
|
||||||
for (int j=0; j<ACTION_NUM; j++){
|
// for (int j=0; j<ACTION_NUM; j++){
|
||||||
*(table + i * ACTION_NUM + j) = 0;
|
// *(table + i * ACTION_NUM + j) = 0;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Give the chessboard & state, it will return the max reward with the best choice
|
Give the chessboard & state, it will return the max reward with the best choice
|
||||||
@ -137,14 +151,24 @@ void init_table(float *table){
|
|||||||
Results:
|
Results:
|
||||||
- int max_reward
|
- int max_reward
|
||||||
*/
|
*/
|
||||||
float get_estimate_reward(float *table, short *board, int state){
|
float get_estimate_reward(struct Node **map, short *board, char *state){
|
||||||
short available_actions[9];
|
short available_actions[ACTION_NUM];
|
||||||
short available_action_length;
|
short available_action_length;
|
||||||
get_available_actions(board, available_actions, &available_action_length);
|
get_available_actions(board, available_actions, &available_action_length);
|
||||||
|
|
||||||
float available_actions_state[9];
|
// find weights in the hash table
|
||||||
|
float state_weights[ACTION_NUM];
|
||||||
|
bool find;
|
||||||
|
search(map, state, &find, state_weights);
|
||||||
|
if (!find) {
|
||||||
|
for (short i=0; i<ACTION_NUM; i++){
|
||||||
|
state_weights[i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float available_actions_state[ACTION_NUM];
|
||||||
for (short i=0; i<available_action_length; i++){
|
for (short i=0; i<available_action_length; i++){
|
||||||
available_actions_state[i] = *(table + state * ACTION_NUM + available_actions[i]); // table[state][available_actions[i]]
|
available_actions_state[i] = state_weights[available_actions[i]]; // table[state][available_actions[i]]
|
||||||
}
|
}
|
||||||
|
|
||||||
short ans_index;
|
short ans_index;
|
||||||
@ -165,37 +189,46 @@ float get_estimate_reward(float *table, short *board, int state){
|
|||||||
Results:
|
Results:
|
||||||
- None
|
- None
|
||||||
*/
|
*/
|
||||||
void run(float *table, short *board, bool train, int times, bool plot){
|
void run(struct Node **map, short *board, bool train, int times, bool plot){
|
||||||
short available_actions[9];
|
short available_actions[ACTION_NUM];
|
||||||
short available_actions_length;
|
short available_actions_length;
|
||||||
short winner;
|
short winner;
|
||||||
short choice, opponent_choice;
|
short choice, opponent_choice;
|
||||||
int state, _state;
|
char state[BIGNUM_LEN], _state[BIGNUM_LEN];
|
||||||
float estimate_r, estimate_r_, real_r, r, opponent_r;
|
float estimate_r, estimate_r_, real_r, r, opponent_r;
|
||||||
struct action a;
|
struct action a;
|
||||||
|
float state_weights[ACTION_NUM];
|
||||||
|
bool find;
|
||||||
int win = 0;
|
int win = 0;
|
||||||
|
|
||||||
for (int episode=0; episode<times; episode++){
|
for (int episode=0; episode<times; episode++){
|
||||||
reset(board);
|
reset(board);
|
||||||
state = state_hash(board);
|
state_hash(board, state);
|
||||||
while (1){
|
while (1){
|
||||||
// bot choose the action
|
// bot choose the action
|
||||||
choice = bot_choose_action(table, board, state);
|
choice = bot_choose_action(map, board, state);
|
||||||
a.loc = choice;
|
a.loc = choice;
|
||||||
a.player = BOT_SYMBOL;
|
a.player = BOT_SYMBOL;
|
||||||
|
|
||||||
estimate_r = *(table + state * ACTION_NUM + choice);
|
search(map, state, &find, state_weights);
|
||||||
act(board, &a, &_state, &r, &opponent_r, &winner);
|
if (!find) {
|
||||||
|
for (short i=0; i<ACTION_NUM; i++){
|
||||||
|
state_weights[i] = 0.0;
|
||||||
|
}
|
||||||
|
if (train)
|
||||||
|
insert(map, state);
|
||||||
|
}
|
||||||
|
estimate_r = state_weights[choice];
|
||||||
|
act(board, &a, _state, &r, &opponent_r, &winner);
|
||||||
if (plot) show(board);
|
if (plot) show(board);
|
||||||
|
|
||||||
// opponent random
|
// // opponent random
|
||||||
if (winner == 0){
|
if (winner == 0){
|
||||||
opponent_choice = opponent_random_action(table, board, state_hash(board));
|
opponent_choice = opponent_random_action(board);
|
||||||
if (opponent_choice != -1){
|
if (opponent_choice != -1){
|
||||||
a.loc = opponent_choice;
|
a.loc = opponent_choice;
|
||||||
a.player = OPPONENT_SYMBOL;
|
a.player = OPPONENT_SYMBOL;
|
||||||
act(board, &a, &_state, &opponent_r, &r, &winner);
|
act(board, &a, _state, &opponent_r, &r, &winner);
|
||||||
if (plot) show(board);
|
if (plot) show(board);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -208,17 +241,19 @@ void run(float *table, short *board, bool train, int times, bool plot){
|
|||||||
}
|
}
|
||||||
real_r = r;
|
real_r = r;
|
||||||
} else {
|
} else {
|
||||||
estimate_r_ = get_estimate_reward(table, board, _state);
|
estimate_r_ = get_estimate_reward(map, board, _state);
|
||||||
real_r = r + LAMBDA * estimate_r_;
|
real_r = r + LAMBDA * estimate_r_;
|
||||||
}
|
}
|
||||||
if (train){
|
if (train){
|
||||||
// printf("update");
|
state_weights[choice] += (LR * (real_r - estimate_r));
|
||||||
*(table + state * ACTION_NUM + choice) += ( LR * (real_r - estimate_r) ); // table[state][choice] += LR * (real_r - estimate_r)
|
update(map, state, choice, state_weights[choice]);
|
||||||
}
|
}
|
||||||
state = _state;
|
for (int i=0; i<BIGNUM_LEN; i++){
|
||||||
|
state[i] = _state[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if ((winner != 0) || (available_actions_length == 0)){
|
if ((winner != 0) || (available_actions_length == 0)){
|
||||||
// printf("break\n");
|
|
||||||
if (winner == 1){
|
if (winner == 1){
|
||||||
win += 1;
|
win += 1;
|
||||||
}
|
}
|
||||||
@ -228,5 +263,6 @@ void run(float *table, short *board, bool train, int times, bool plot){
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!train)
|
if (!train)
|
||||||
printf("%d/%d, %f\%\n", win, 10000, (float)win/10000);
|
// printf("%d/%d, %f\%\n", win, 10000, (float)win/10000);
|
||||||
|
printf("%f\n", (float)win/times);
|
||||||
}
|
}
|
||||||
|
|||||||
11
q-learning.h
11
q-learning.h
@ -1,6 +1,7 @@
|
|||||||
|
#include "hash-table.h"
|
||||||
|
|
||||||
short float_argmax(float *arr, short length);
|
short float_argmax(float *arr, short length);
|
||||||
short bot_choose_action(float *table, short *board, int state);
|
short bot_choose_action(struct Node **map, short *board, char *state);
|
||||||
short opponent_random_action(float *table, short *board, int state);
|
short opponent_random_action(short *board);
|
||||||
void init_table(float *table);
|
float get_estimate_reward(struct Node **map, short *board, char *state);
|
||||||
float get_estimate_reward(float *table, short *board, int state);
|
void run(struct Node **map, short *board, bool train, int times, bool plot);
|
||||||
void run(float *table, short *board, bool train, int times, bool plot);
|
|
||||||
Loading…
Reference in New Issue
Block a user