feat: change q-learning method to fit 'on 4 in a row'

This commit is contained in:
snsd0805 2023-06-02 23:34:37 +08:00
parent b024eec8e4
commit 0fb79b3e1e
Signed by: snsd0805
GPG Key ID: 569349933C77A854
6 changed files with 148 additions and 141 deletions

View File

@ -6,11 +6,12 @@
#define LAMBDA 0.9 // discount factor #define LAMBDA 0.9 // discount factor
#define STATE_NUM 19683 #define STATE_NUM 19683
#define ACTION_NUM 9 #define ACTION_NUM 7
#define EPISODE_NUM 100000 #define EPISODE_NUM 1000000
#define FIRST true #define FIRST true
#define ROW_NUM 6 #define ROW_NUM 6
#define COL_NUM 7 #define COL_NUM 7
#define BIGNUM_LEN 22 #define BIGNUM_LEN 22
#define TABLE_SIZE 1000000000

View File

@ -3,14 +3,8 @@
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>
#include <time.h> #include <time.h>
#include "hash-table.h"
#define TABLE_SIZE 10 #include "constant.h"
struct Node {
char key[48];
int value;
struct Node *next;
};
long long hash_function(char *key) { long long hash_function(char *key) {
long long hash = 0; long long hash = 0;
@ -20,115 +14,69 @@ long long hash_function(char *key) {
return hash ; return hash ;
} }
void insert(struct Node **table, char *key, int value) { void insert(struct Node **table, char *key) {
long long hash = hash_function(key); long long hash = hash_function(key);
printf("Hash: %lli\n", hash);
struct Node *node = malloc(sizeof(struct Node)); struct Node *node = malloc(sizeof(struct Node));
struct Node *temp, *past; struct Node *temp, *past;
strcpy(node->key, key); strcpy(node->key, key);
node->value = value; // init
for (short i=0; i<ACTION_NUM; i++){
node->value[i] = 0.0;
}
node->next = NULL; node->next = NULL;
if (table[hash] == NULL){ if (table[hash] == NULL){
table[hash] = node; table[hash] = node;
printf("Create.\n");
} else { } else {
printf("Add.\n");
temp = table[hash]; temp = table[hash];
past = NULL; past = NULL;
while(temp != NULL){ while(temp != NULL){
assert(temp->key != key); assert(strcmp(temp->key, key)!=0);
printf("%s -> ", temp->key);
past = temp; past = temp;
temp = temp->next; temp = temp->next;
} }
printf("\n");
past->next = node; past->next = node;
} }
} }
void long_to_str(long long num, char *s, int length) { void search(struct Node **table, char *key, bool *find, float *ans) {
int temp;
for (int i=length-1; i>=0; i--) {
temp = num % 10;
num /= 10;
s[i] = (char)(temp + 48);
}
}
int search(struct Node **table, char *key) {
long long hash = hash_function(key); long long hash = hash_function(key);
struct Node *temp, *past; struct Node *temp, *past;
*find = false;
if (table[hash] == NULL){ if (table[hash] != NULL){
return -1;
} else {
temp = table[hash]; temp = table[hash];
past = NULL; past = NULL;
while(temp != NULL){ while(temp != NULL){
// printf("%s - %s\n", temp->key, key);
if (strcmp(temp->key, key) == 0){ if (strcmp(temp->key, key) == 0){
return temp->value; *find = true;
for (short i=0; i<ACTION_NUM; i++){
ans[i] = temp->value[i];
}
break;
} }
past = temp; past = temp;
temp = temp->next; temp = temp->next;
} }
return -1;
} }
} }
void update(struct Node **table, char *key, int value) { void update(struct Node **table, char *key, short action, float value) {
long long hash = hash_function(key); long long hash = hash_function(key);
struct Node *temp, *past; struct Node *temp, *past;
assert(table[hash]!=NULL);
temp = table[hash]; temp = table[hash];
past = NULL; past = NULL;
while(temp != NULL){ while(temp != NULL){
if (strcmp(temp->key, key) == 0){ if (strcmp(temp->key, key) == 0){
temp->value = value; temp->value[action] = value;
break; break;
} }
past = temp; past = temp;
temp = temp->next; temp = temp->next;
} }
} }
int main(){
struct Node ** table; // pointer to pointer
int size;
srand(time(NULL));
table = malloc(TABLE_SIZE * sizeof(struct Node*));
for (int i=0; i<TABLE_SIZE; i++){
table[i] = NULL;
}
long long a = 1234567890;
char s[21];
for (int i=0; i<50; i++){
a = (long long)rand();
printf("%lli\n", a);
long_to_str(a, s, 20);
printf("%s\n", s);
insert(table, s, i);
printf("\n");
}
int ans;
while (1) {
printf("> ");
scanf("%lli", &a);
printf("HERE\n");
long_to_str(a, s, 20);
printf("HERE\n");
update(table, s, 100);
ans = search(table, s);
printf("%d\n\n", ans);
}
// long long a = hash_function("9999999999999");
// printf("%lli\n", a);
}

13
hash-table.h Normal file
View File

@ -0,0 +1,13 @@
#include "constant.h"
#include <stdbool.h>
struct Node {
char key[BIGNUM_LEN+1];
float value[ACTION_NUM];
struct Node *next;
};
long long hash_function(char *key);
void insert(struct Node **table, char *key);
void search(struct Node **table, char *key, bool *find, float *ans);
void update(struct Node **table, char *key, short action, float value);

22
main.c
View File

@ -7,13 +7,21 @@
#include "q-learning.h" #include "q-learning.h"
int main(){ int main(){
short board[9]= {0}; // tic tac toe's chessboard short board[ROW_NUM][COL_NUM]= {0};
float table[STATE_NUM][ACTION_NUM]; // q-learning table short winner;
struct Node ** map; // pointer to pointer, hash table
bool find;
float state[ACTION_NUM];
srand(time(NULL)); srand(time(NULL));
init_table(&table[0][0]);
run(&table[0][0], board, false, 10000, false); // init hash table
run(&table[0][0], board, true, EPISODE_NUM, false); map = malloc(TABLE_SIZE * sizeof(struct Node*));
run(&table[0][0], board, false, 10000, false); for (int i=0; i<TABLE_SIZE; i++){
map[i] = NULL;
}
run(map, &board[0][0], false, 10000, false);
run(map, &board[0][0], true, EPISODE_NUM, false);
run(map, &board[0][0], false, 10000, true);
} }

View File

@ -6,6 +6,7 @@
#include "constant.h" #include "constant.h"
#include "enviroment.h" #include "enviroment.h"
#include "hash-table.h"
/* /*
Return the index with the max value in the array Return the index with the max value in the array
@ -37,36 +38,49 @@ short float_argmax(float *arr, short length){
Args: Args:
- short *table (array's address): state table for Q-Learning - short *table (array's address): state table for Q-Learning
- short *board (array's address): chessboards' status - short *board (array's address): chessboards' status
- int state (integer, state hash): hash for board's status - char *state (string, state hash): hash for board's status
Results: Results:
- short best_choice - short best_choice
*/ */
short bot_choose_action(float *table, short *board, int state){ short bot_choose_action(struct Node **map, short *board, char *state){
// get available actions for choosing // get available actions for choosing
short available_actions[9]; short available_actions[ACTION_NUM];
short available_actions_length; short available_actions_length;
get_available_actions(board, available_actions, &available_actions_length); get_available_actions(board, available_actions, &available_actions_length);
// use argmax() to find the best choise, // use argmax() to find the best choise,
// first we should build an available_actions_state array for saving the state for all available choise. // first we should build an available_actions_state array for saving the state for all available choise.
float available_actions_state[9]; float available_actions_state[ACTION_NUM];
short available_actions_state_index[9]; short available_actions_state_index[ACTION_NUM];
short available_actions_state_length, index = 0; short available_actions_state_length, index = 0;
short temp_index, best_choice; short temp_index, best_choice;
bool zeros = true; bool zeros = true;
for (short i=0; i<available_actions_length; i++){ bool find;
temp_index = available_actions[i]; float state_weights[ACTION_NUM];
available_actions_state[index] = *(table + state * ACTION_NUM + temp_index);
if (available_actions_state[index] != 0.0){ // find weights in the hash table
zeros = false; search(map, state, &find, state_weights);
} if (!find) {
available_actions_state_index[index] = temp_index; for (short i=0; i<ACTION_NUM; i++){
index++; state_weights[i] = 0.0;
} }
best_choice = float_argmax(available_actions_state, index); }
best_choice = available_actions_state_index[best_choice];
// get the best choice
for (short i=0; i<available_actions_length; i++){
temp_index = available_actions[i];
available_actions_state[index] = state_weights[temp_index];
if (available_actions_state[index] != 0.0){
zeros = false;
}
available_actions_state_index[index] = temp_index;
index++;
}
best_choice = float_argmax(available_actions_state, index);
best_choice = available_actions_state_index[best_choice];
// Epsilon-Greedy // Epsilon-Greedy
// If random number > EPSILON -> random a action // If random number > EPSILON -> random a action
@ -83,17 +97,15 @@ short bot_choose_action(float *table, short *board, int state){
Opponent random choose a action to do. Opponent random choose a action to do.
Args: Args:
- short *table (array's address): state table for Q-Learning
- short *board (array's address): chessboards' status - short *board (array's address): chessboards' status
- int state (integer, state hash): hash for board's status
Results: Results:
- short choice (integer): random, -1 means no available action to choose - short choice (integer): random, -1 means no available action to choose
*/ */
short opponent_random_action(float *table, short *board, int state){ short opponent_random_action(short *board){
// get available actions for choosing // get available actions for choosing
short available_actions[9]; short available_actions[ACTION_NUM];
short available_action_length; short available_action_length;
get_available_actions(board, available_actions, &available_action_length); get_available_actions(board, available_actions, &available_action_length);
@ -109,22 +121,24 @@ short opponent_random_action(float *table, short *board, int state){
return choice; return choice;
} }
/* // Use Hash Table, so we needn't initilize Q-Table
Inilialize the Q-Table //
// /*
// Inilialize the Q-Table
Args: // Args:
- float *table (two-dim array's start address) // - float *table (two-dim array's start address)
Results: // Results:
- None. // - None.
*/ // */
void init_table(float *table){ // void init_table(float *table){
for (int i=0; i<STATE_NUM; i++){ // for (int i=0; i<STATE_NUM; i++){
for (int j=0; j<ACTION_NUM; j++){ // for (int j=0; j<ACTION_NUM; j++){
*(table + i * ACTION_NUM + j) = 0; // *(table + i * ACTION_NUM + j) = 0;
} // }
} // }
} // }
/* /*
Give the chessboard & state, it will return the max reward with the best choice Give the chessboard & state, it will return the max reward with the best choice
@ -137,14 +151,24 @@ void init_table(float *table){
Results: Results:
- int max_reward - int max_reward
*/ */
float get_estimate_reward(float *table, short *board, int state){ float get_estimate_reward(struct Node **map, short *board, char *state){
short available_actions[9]; short available_actions[ACTION_NUM];
short available_action_length; short available_action_length;
get_available_actions(board, available_actions, &available_action_length); get_available_actions(board, available_actions, &available_action_length);
float available_actions_state[9]; // find weights in the hash table
float state_weights[ACTION_NUM];
bool find;
search(map, state, &find, state_weights);
if (!find) {
for (short i=0; i<ACTION_NUM; i++){
state_weights[i] = 0.0;
}
}
float available_actions_state[ACTION_NUM];
for (short i=0; i<available_action_length; i++){ for (short i=0; i<available_action_length; i++){
available_actions_state[i] = *(table + state * ACTION_NUM + available_actions[i]); // table[state][available_actions[i]] available_actions_state[i] = state_weights[available_actions[i]]; // table[state][available_actions[i]]
} }
short ans_index; short ans_index;
@ -165,37 +189,46 @@ float get_estimate_reward(float *table, short *board, int state){
Results: Results:
- None - None
*/ */
void run(float *table, short *board, bool train, int times, bool plot){ void run(struct Node **map, short *board, bool train, int times, bool plot){
short available_actions[9]; short available_actions[ACTION_NUM];
short available_actions_length; short available_actions_length;
short winner; short winner;
short choice, opponent_choice; short choice, opponent_choice;
int state, _state; char state[BIGNUM_LEN], _state[BIGNUM_LEN];
float estimate_r, estimate_r_, real_r, r, opponent_r; float estimate_r, estimate_r_, real_r, r, opponent_r;
struct action a; struct action a;
float state_weights[ACTION_NUM];
bool find;
int win = 0; int win = 0;
for (int episode=0; episode<times; episode++){ for (int episode=0; episode<times; episode++){
reset(board); reset(board);
state = state_hash(board); state_hash(board, state);
while (1){ while (1){
// bot choose the action // bot choose the action
choice = bot_choose_action(table, board, state); choice = bot_choose_action(map, board, state);
a.loc = choice; a.loc = choice;
a.player = BOT_SYMBOL; a.player = BOT_SYMBOL;
estimate_r = *(table + state * ACTION_NUM + choice); search(map, state, &find, state_weights);
act(board, &a, &_state, &r, &opponent_r, &winner); if (!find) {
for (short i=0; i<ACTION_NUM; i++){
state_weights[i] = 0.0;
}
if (train)
insert(map, state);
}
estimate_r = state_weights[choice];
act(board, &a, _state, &r, &opponent_r, &winner);
if (plot) show(board); if (plot) show(board);
// opponent random // // opponent random
if (winner == 0){ if (winner == 0){
opponent_choice = opponent_random_action(table, board, state_hash(board)); opponent_choice = opponent_random_action(board);
if (opponent_choice != -1){ if (opponent_choice != -1){
a.loc = opponent_choice; a.loc = opponent_choice;
a.player = OPPONENT_SYMBOL; a.player = OPPONENT_SYMBOL;
act(board, &a, &_state, &opponent_r, &r, &winner); act(board, &a, _state, &opponent_r, &r, &winner);
if (plot) show(board); if (plot) show(board);
} }
} }
@ -208,17 +241,19 @@ void run(float *table, short *board, bool train, int times, bool plot){
} }
real_r = r; real_r = r;
} else { } else {
estimate_r_ = get_estimate_reward(table, board, _state); estimate_r_ = get_estimate_reward(map, board, _state);
real_r = r + LAMBDA * estimate_r_; real_r = r + LAMBDA * estimate_r_;
} }
if (train){ if (train){
// printf("update"); state_weights[choice] += (LR * (real_r - estimate_r));
*(table + state * ACTION_NUM + choice) += ( LR * (real_r - estimate_r) ); // table[state][choice] += LR * (real_r - estimate_r) update(map, state, choice, state_weights[choice]);
} }
state = _state; for (int i=0; i<BIGNUM_LEN; i++){
state[i] = _state[i];
}
if ((winner != 0) || (available_actions_length == 0)){ if ((winner != 0) || (available_actions_length == 0)){
// printf("break\n");
if (winner == 1){ if (winner == 1){
win += 1; win += 1;
} }
@ -228,5 +263,6 @@ void run(float *table, short *board, bool train, int times, bool plot){
} }
if (!train) if (!train)
printf("%d/%d, %f\%\n", win, 10000, (float)win/10000); // printf("%d/%d, %f\%\n", win, 10000, (float)win/10000);
printf("%f\n", (float)win/times);
} }

View File

@ -1,6 +1,7 @@
#include "hash-table.h"
short float_argmax(float *arr, short length); short float_argmax(float *arr, short length);
short bot_choose_action(float *table, short *board, int state); short bot_choose_action(struct Node **map, short *board, char *state);
short opponent_random_action(float *table, short *board, int state); short opponent_random_action(short *board);
void init_table(float *table); float get_estimate_reward(struct Node **map, short *board, char *state);
float get_estimate_reward(float *table, short *board, int state); void run(struct Node **map, short *board, bool train, int times, bool plot);
void run(float *table, short *board, bool train, int times, bool plot);