Merge: replace 'tic-tac-toe' with '4 in a row'
'tic-tac-toe' version is located on the other branch.
This commit is contained in:
commit
c3a0335ff1
4
Makefile
4
Makefile
@ -1,5 +1,5 @@
|
||||
all: a.out
|
||||
a.out: main.c enviroment.c enviroment.h q-learning.c q-learning.h constant.h
|
||||
gcc main.c enviroment.c q-learning.c -lm
|
||||
a.out: main.c enviroment.c enviroment.h q-learning.c q-learning.h bignum.c bignum.h hash-table.c hash-table.h constant.h
|
||||
gcc main.c enviroment.c q-learning.c bignum.c constant.h hash-table.c -lm
|
||||
run:
|
||||
./a.out
|
||||
|
||||
43
bignum.c
Normal file
43
bignum.c
Normal file
@ -0,0 +1,43 @@
|
||||
#include <stdio.h>
|
||||
#include <stdbool.h>
|
||||
#include "bignum.h"
|
||||
#include "constant.h"
|
||||
|
||||
struct BigNum long_to_BigNum(long long num) {
|
||||
struct BigNum ans;
|
||||
int temp;
|
||||
for (int i=BIGNUM_LEN-1; i>=0; i--) {
|
||||
temp = num % 10;
|
||||
num /= 10;
|
||||
ans.num[i] = (char)(temp + 48);
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
struct BigNum add(struct BigNum a, struct BigNum b) {
|
||||
struct BigNum ans;
|
||||
short s, carry=0;
|
||||
|
||||
for (short i=BIGNUM_LEN-1; i>=0; i--) {
|
||||
s = (a.num[i]-48) + (b.num[i]-48) + carry;
|
||||
carry = s / 10;
|
||||
s %= 10;
|
||||
ans.num[i] = (char)(s+48);
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
struct BigNum mul(struct BigNum a, int b) {
|
||||
struct BigNum ans;
|
||||
short s, carry=0;
|
||||
|
||||
for (short i=BIGNUM_LEN-1; i>=0; i--) {
|
||||
s = (a.num[i]-48) * b + carry;
|
||||
carry = s / 10;
|
||||
s %= 10;
|
||||
ans.num[i] = (char)(s+48);
|
||||
// printf("index(%hd): %c\n", i, (char)(s+48));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
8
bignum.h
Normal file
8
bignum.h
Normal file
@ -0,0 +1,8 @@
|
||||
#include "constant.h"
|
||||
|
||||
struct BigNum {
|
||||
char num[BIGNUM_LEN+1];
|
||||
};
|
||||
struct BigNum long_to_BigNum(long long num);
|
||||
struct BigNum add(struct BigNum a, struct BigNum b);
|
||||
struct BigNum mul(struct BigNum a, int b);
|
||||
10
constant.h
10
constant.h
@ -6,6 +6,12 @@
|
||||
#define LAMBDA 0.9 // discount factor
|
||||
|
||||
#define STATE_NUM 19683
|
||||
#define ACTION_NUM 9
|
||||
#define EPISODE_NUM 100000
|
||||
#define ACTION_NUM 7
|
||||
#define EPISODE_NUM 1000000
|
||||
#define FIRST true
|
||||
|
||||
#define ROW_NUM 6
|
||||
#define COL_NUM 7
|
||||
|
||||
#define BIGNUM_LEN 22
|
||||
#define TABLE_SIZE 1000000000
|
||||
|
||||
253
enviroment.c
253
enviroment.c
@ -1,28 +1,34 @@
|
||||
#include "enviroment.h"
|
||||
#include "constant.h"
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include "constant.h"
|
||||
#include "enviroment.h"
|
||||
#include "bignum.h"
|
||||
|
||||
short PATHS[8][3] = {
|
||||
{ 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 },
|
||||
{ 0, 3, 6 }, { 1, 4, 7 }, { 2, 5, 8 },
|
||||
{ 0, 4, 8 }, { 2, 4, 6 }
|
||||
struct BigNum POWs[42] = {
|
||||
"0000000000000000000001", "0000000000000000000003", "0000000000000000000009", "0000000000000000000027", "0000000000000000000081",
|
||||
"0000000000000000000243", "0000000000000000000729", "0000000000000000002187", "0000000000000000006561", "0000000000000000019683",
|
||||
"0000000000000000059049", "0000000000000000177147", "0000000000000000531441", "0000000000000001594323", "0000000000000004782969",
|
||||
"0000000000000014348907", "0000000000000043046721", "0000000000000129140163", "0000000000000387420489", "0000000000001162261467",
|
||||
"0000000000003486784401", "0000000000010460353203", "0000000000031381059609", "0000000000094143178827", "0000000000282429536481",
|
||||
"0000000000847288609443", "0000000002541865828329", "0000000007625597484987", "0000000022876792454961", "0000000068630377364883",
|
||||
"0000000205891132094649", "0000000617673396283947", "0000001853020188851841", "0000005559060566555523", "0000016677181699666569",
|
||||
"0000050031545098999707", "0000150094635296999121", "0000450283905890997363", "0001350851717672992089", "0004052555153018976267",
|
||||
"0012157665459056928801", "0036472996377170786403"
|
||||
};
|
||||
|
||||
/*
|
||||
Reset the game, clear the chessboard.
|
||||
|
||||
Args:
|
||||
- short *board (array's address): chessboard's status
|
||||
Args:
|
||||
- short *board (array's start address): chessboard's status
|
||||
|
||||
Results:
|
||||
- None, set all blocks on the chessboard to zero.
|
||||
*/
|
||||
void reset(short* board)
|
||||
{
|
||||
for (short i = 0; i < 9; i++)
|
||||
board[i] = 0;
|
||||
void reset(short* board){
|
||||
for (short i=0; i<(ROW_NUM*COL_NUM); i++)
|
||||
board[i] = 0;
|
||||
}
|
||||
|
||||
/*
|
||||
@ -34,23 +40,23 @@ void reset(short* board)
|
||||
Results:
|
||||
- None. Only printing.
|
||||
*/
|
||||
void show(short* board)
|
||||
{
|
||||
short loc;
|
||||
printf("┼───┼───┼───┼\n");
|
||||
for (short i = 0; i < 3; i++) {
|
||||
printf("│ ");
|
||||
for (short j = 0; j < 3; j++) {
|
||||
loc = 3 * i + j;
|
||||
if (board[loc] == 0)
|
||||
printf(" │ ");
|
||||
else if (board[loc] == BOT_SYMBOL)
|
||||
printf("○ │ ");
|
||||
else
|
||||
printf("✕ │ ");
|
||||
void show(short *board){
|
||||
short loc;
|
||||
for (short i=0; i<COL_NUM; i++){
|
||||
printf("%d ", i);
|
||||
}
|
||||
printf("\n");
|
||||
for (short i=(ROW_NUM*COL_NUM-1); i>=0; i--){
|
||||
if (board[i] == BOT_SYMBOL) {
|
||||
printf("● ");
|
||||
} else if(board[i] == OPPONENT_SYMBOL) {
|
||||
printf("◴ ");
|
||||
} else {
|
||||
printf("◌ ");
|
||||
}
|
||||
if (i%COL_NUM == 0){
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
printf("┼───┼───┼───┼\n");
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
@ -66,80 +72,177 @@ void show(short* board)
|
||||
Results:
|
||||
- None. All available actions are saved into "result" and the number of actions is saved in "length"
|
||||
*/
|
||||
void get_available_actions(short* board, short* result, short* length)
|
||||
{
|
||||
short index = 0;
|
||||
for (int i = 0; i < 9; i++)
|
||||
if (board[i] == 0)
|
||||
result[index++] = i;
|
||||
*length = index;
|
||||
|
||||
void get_available_actions(short *board, short *result, short *length){
|
||||
short index = 0;
|
||||
for (int i=0; i<COL_NUM; i++)
|
||||
if (board[(ROW_NUM*COL_NUM-1)-i] == 0)
|
||||
result[index++] = i;
|
||||
*length = index;
|
||||
}
|
||||
|
||||
/*
|
||||
Return winner's number;
|
||||
Get value in the board with validation.
|
||||
|
||||
Args:
|
||||
- short *board (array's start pointer): chessboard's status
|
||||
- short row (integer): loc's row number
|
||||
- short col (integer): loc's col number
|
||||
|
||||
Results:
|
||||
- short value (integer): means the value in chessboard[row][col]
|
||||
*/
|
||||
short get_loc_status(short *board, short row, short col) {
|
||||
if ((row >= ROW_NUM) || (row < 0)) {
|
||||
return -1;
|
||||
}
|
||||
if ((col >= COL_NUM) || (col < 0)) {
|
||||
return -1;
|
||||
}
|
||||
return board[row*COL_NUM+col];
|
||||
}
|
||||
|
||||
/*
|
||||
Return winner's number;
|
||||
|
||||
Args:
|
||||
- short *board (array's address): chessboard's status
|
||||
|
||||
Results:
|
||||
- short winner_number(integer): winner's number, 0 for no winner now, 1 for Bot, 2 for opponent
|
||||
Results:
|
||||
- short winner_number(integer): winner's number, 0 for no winner now, 1 for Bot, 2 for opponent
|
||||
|
||||
board's coodinate diagram
|
||||
^
|
||||
| 5
|
||||
| 4
|
||||
| 3
|
||||
| 2
|
||||
| 1
|
||||
| 0
|
||||
<-----------------------------
|
||||
6 5 4 3 2 1 0 |
|
||||
*/
|
||||
short get_winner(short* board)
|
||||
{
|
||||
int a, b, c;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
a = PATHS[i][0];
|
||||
b = PATHS[i][1];
|
||||
c = PATHS[i][2];
|
||||
if ((board[a] == board[b]) && (board[b] == board[c]) && (board[a] != 0)) {
|
||||
return board[a];
|
||||
short get_winner(short *board){
|
||||
short a, b, c, d;
|
||||
for (short i=0; i<ROW_NUM; i++){
|
||||
for (short j=0; j<COL_NUM; j++){
|
||||
// horizontal
|
||||
a = get_loc_status(board, i, j);
|
||||
b = get_loc_status(board, i, j+1);
|
||||
c = get_loc_status(board, i, j+2);
|
||||
d = get_loc_status(board, i, j+3);
|
||||
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
|
||||
return a;
|
||||
}
|
||||
|
||||
// vertical
|
||||
a = get_loc_status(board, i, j);
|
||||
b = get_loc_status(board, i+1, j);
|
||||
c = get_loc_status(board, i+2, j);
|
||||
d = get_loc_status(board, i+3, j);
|
||||
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
|
||||
return a;
|
||||
}
|
||||
|
||||
// slash (/)
|
||||
a = get_loc_status(board, i, j);
|
||||
b = get_loc_status(board, i+1, j-1);
|
||||
c = get_loc_status(board, i+2, j-2);
|
||||
d = get_loc_status(board, i+3, j-3);
|
||||
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
|
||||
return a;
|
||||
}
|
||||
|
||||
// backslash (\)
|
||||
a = get_loc_status(board, i, j);
|
||||
b = get_loc_status(board, i+1, j+1);
|
||||
c = get_loc_status(board, i+2, j+2);
|
||||
d = get_loc_status(board, i+3, j+3);
|
||||
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
Hash chesstable's status into hash.
|
||||
|
||||
Args:
|
||||
- short *board (array's address): chessboard's status
|
||||
- char *hash (a string): size is BIGNUM_LEN, the hash will be wrote here
|
||||
|
||||
Results:
|
||||
- None.
|
||||
*/
|
||||
void state_hash(short *board, char *hash){
|
||||
struct BigNum sum, temp;
|
||||
for (short i=0; i<BIGNUM_LEN; i++){
|
||||
sum.num[i] = '0';
|
||||
}
|
||||
|
||||
for (short i=0; i<(ROW_NUM*COL_NUM); i++) {
|
||||
// printf("MUL:\n");
|
||||
// printf("%s\n", POWs[i].num);
|
||||
temp = mul(POWs[i], board[i]);
|
||||
// printf("%s\n\n", temp.num);
|
||||
|
||||
// printf("ADD:\n");
|
||||
// printf("%s\n", sum.num);
|
||||
// printf("%s\n", temp.num);
|
||||
sum = add(sum, temp);
|
||||
// printf("%s\n\n", sum.num);
|
||||
|
||||
}
|
||||
|
||||
for (int i=0; i<BIGNUM_LEN; i++){
|
||||
hash[i] = sum.num[i];
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Fall the chess on the board.
|
||||
|
||||
Args:
|
||||
- short *board (array's address): chessboard's status
|
||||
- short *board: chessboard
|
||||
- struct action *a (struct pointer): action's loc & player
|
||||
|
||||
Results:
|
||||
- int hash (integer): chessboard's status in i-th block * pow(3, i)
|
||||
- None. Fall chess on the chessboard
|
||||
*/
|
||||
int state_hash(short* board)
|
||||
{
|
||||
int base, hash = 0;
|
||||
for (int i = 0; i < 9; i++) {
|
||||
base = pow(3, i);
|
||||
hash += (base * board[i]);
|
||||
void fall(short *board, struct action *a) {
|
||||
short *ptr = (board + ROW_NUM * COL_NUM - 1 - (a->loc));
|
||||
while ((*ptr == 0) && (ptr>=board)) {
|
||||
// printf("%d ", *ptr);
|
||||
ptr -= COL_NUM;
|
||||
}
|
||||
return hash;
|
||||
*(ptr+COL_NUM) = a->player;
|
||||
}
|
||||
|
||||
/*
|
||||
Act on the chessboard.
|
||||
|
||||
Args:
|
||||
- short *board (array's address): chessboards' status
|
||||
- struct action *a (a action's pointer): include player & choose loc
|
||||
- int *state (pointer): for return. To save the chessboard's state hash which after doing this action
|
||||
- float *reward (pointer): for return. To save the number of rewards which the player gets after doing this action.
|
||||
- float *opponent_reward (pointer): for return. To save the number of rewards which the opponents gets after the player doing this action.
|
||||
- short *winner (pointer): for return. To save the winner in this action. If haven't finish, it will be zero.
|
||||
Args:
|
||||
- short *board (array's address): chessboards' status
|
||||
- struct action *a (a action's pointer): include player & choose loc
|
||||
- char *state (a string): for return. To save the chessboard's state hash which after doing this action
|
||||
- float *reward (pointer): for return. To save the number of rewards which the player gets after doing this action.
|
||||
- float *opponent_reward (pointer): for return. To save the number of rewards which the opponents gets after the player doing this action.
|
||||
- short *winner (pointer): for return. To save the winner in this action. If haven't finish, it will be zero.
|
||||
|
||||
Results:
|
||||
- None. Save in state & reward & winner
|
||||
*/
|
||||
void act(short* board, struct action* a, int* state, float* reward, float* opponent_reward, short* winner)
|
||||
{
|
||||
void act(short *board, struct action *a, char *state, float *reward, float *opponent_reward, short *winner){
|
||||
// printf("Act( player=%d, action=%d )\n", a->player, a->loc);
|
||||
assert(board[a->loc] == 0);
|
||||
board[a->loc] = a->player;
|
||||
*winner = get_winner(board);
|
||||
*state = state_hash(board);
|
||||
if (*winner == a->player) {
|
||||
*reward = 1.0;
|
||||
assert(board[(ROW_NUM*COL_NUM-1)-(a->loc)] == 0);
|
||||
|
||||
fall(board, a);
|
||||
*winner = get_winner(board);
|
||||
state_hash(board, state);
|
||||
if (*winner == a->player){
|
||||
*reward = 1.0;
|
||||
*opponent_reward = -1.0;
|
||||
} else if (*winner != 0) {
|
||||
*reward = -1.0;
|
||||
|
||||
10
enviroment.h
10
enviroment.h
@ -4,8 +4,8 @@ struct action {
|
||||
};
|
||||
|
||||
void reset(short* board);
|
||||
void show(short* board);
|
||||
void get_available_actions(short* board, short* result, short* length);
|
||||
short get_winner(short* board);
|
||||
int state_hash(short* board);
|
||||
void act(short* board, struct action* a, int* state, float* reward, float* opponent_reward, short* winner);
|
||||
void show(short *board);
|
||||
void get_available_actions(short *board, short *result, short *length);
|
||||
short get_winner(short *board);
|
||||
void state_hash(short *board, char *hash);
|
||||
void act(short *board, struct action *a, char *state, float *reward, float *opponent_reward, short *winner);
|
||||
|
||||
82
hash-table.c
Normal file
82
hash-table.c
Normal file
@ -0,0 +1,82 @@
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <time.h>
|
||||
#include "hash-table.h"
|
||||
#include "constant.h"
|
||||
|
||||
long long hash_function(char *key) {
|
||||
long long hash = 0;
|
||||
for (int i=0; i<strlen(key); i++){
|
||||
hash = ((hash * 33) + key[i]) % TABLE_SIZE;
|
||||
}
|
||||
return hash ;
|
||||
}
|
||||
|
||||
void insert(struct Node **table, char *key) {
|
||||
long long hash = hash_function(key);
|
||||
|
||||
|
||||
struct Node *node = malloc(sizeof(struct Node));
|
||||
struct Node *temp, *past;
|
||||
strcpy(node->key, key);
|
||||
// init
|
||||
for (short i=0; i<ACTION_NUM; i++){
|
||||
node->value[i] = 0.0;
|
||||
}
|
||||
node->next = NULL;
|
||||
|
||||
if (table[hash] == NULL){
|
||||
table[hash] = node;
|
||||
} else {
|
||||
temp = table[hash];
|
||||
past = NULL;
|
||||
while(temp != NULL){
|
||||
assert(strcmp(temp->key, key)!=0);
|
||||
past = temp;
|
||||
temp = temp->next;
|
||||
}
|
||||
past->next = node;
|
||||
}
|
||||
}
|
||||
|
||||
void search(struct Node **table, char *key, bool *find, float *ans) {
|
||||
long long hash = hash_function(key);
|
||||
struct Node *temp, *past;
|
||||
*find = false;
|
||||
|
||||
if (table[hash] != NULL){
|
||||
temp = table[hash];
|
||||
past = NULL;
|
||||
|
||||
while(temp != NULL){
|
||||
if (strcmp(temp->key, key) == 0){
|
||||
*find = true;
|
||||
for (short i=0; i<ACTION_NUM; i++){
|
||||
ans[i] = temp->value[i];
|
||||
}
|
||||
break;
|
||||
}
|
||||
past = temp;
|
||||
temp = temp->next;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void update(struct Node **table, char *key, short action, float value) {
|
||||
long long hash = hash_function(key);
|
||||
struct Node *temp, *past;
|
||||
assert(table[hash]!=NULL);
|
||||
|
||||
temp = table[hash];
|
||||
past = NULL;
|
||||
while(temp != NULL){
|
||||
if (strcmp(temp->key, key) == 0){
|
||||
temp->value[action] = value;
|
||||
break;
|
||||
}
|
||||
past = temp;
|
||||
temp = temp->next;
|
||||
}
|
||||
}
|
||||
13
hash-table.h
Normal file
13
hash-table.h
Normal file
@ -0,0 +1,13 @@
|
||||
#include "constant.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
struct Node {
|
||||
char key[BIGNUM_LEN+1];
|
||||
float value[ACTION_NUM];
|
||||
struct Node *next;
|
||||
};
|
||||
|
||||
long long hash_function(char *key);
|
||||
void insert(struct Node **table, char *key);
|
||||
void search(struct Node **table, char *key, bool *find, float *ans);
|
||||
void update(struct Node **table, char *key, short action, float value);
|
||||
23
main.c
23
main.c
@ -6,15 +6,22 @@
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
|
||||
int main()
|
||||
{
|
||||
short board[9] = { 0 }; // tic tac toe's chessboard
|
||||
float table[STATE_NUM][ACTION_NUM]; // q-learning table
|
||||
int main(){
|
||||
short board[ROW_NUM][COL_NUM]= {0};
|
||||
short winner;
|
||||
struct Node ** map; // pointer to pointer, hash table
|
||||
bool find;
|
||||
float state[ACTION_NUM];
|
||||
|
||||
srand(time(NULL));
|
||||
init_table(&table[0][0]);
|
||||
|
||||
run(&table[0][0], board, false, 10000, false);
|
||||
run(&table[0][0], board, true, EPISODE_NUM, false);
|
||||
run(&table[0][0], board, false, 10000, false);
|
||||
// init hash table
|
||||
map = malloc(TABLE_SIZE * sizeof(struct Node*));
|
||||
for (int i=0; i<TABLE_SIZE; i++){
|
||||
map[i] = NULL;
|
||||
}
|
||||
|
||||
run(map, &board[0][0], false, 10000, false);
|
||||
run(map, &board[0][0], true, EPISODE_NUM, false);
|
||||
run(map, &board[0][0], false, 10000, true);
|
||||
}
|
||||
|
||||
183
q-learning.c
183
q-learning.c
@ -6,6 +6,7 @@
|
||||
|
||||
#include "constant.h"
|
||||
#include "enviroment.h"
|
||||
#include "hash-table.h"
|
||||
|
||||
/*
|
||||
Return the index with the max value in the array
|
||||
@ -34,33 +35,45 @@ short float_argmax(float* arr, short length)
|
||||
EPSILON means the probability to choose the best action in this state from Q-Table.
|
||||
(1-EPSILON) to random an action to do.
|
||||
|
||||
Args:
|
||||
- short *table (array's address): state table for Q-Learning
|
||||
- short *board (array's address): chessboards' status
|
||||
- int state (integer, state hash): hash for board's status
|
||||
Args:
|
||||
- short *table (array's address): state table for Q-Learning
|
||||
- short *board (array's address): chessboards' status
|
||||
- char *state (string, state hash): hash for board's status
|
||||
|
||||
Results:
|
||||
- short best_choice
|
||||
*/
|
||||
short bot_choose_action(float* table, short* board, int state)
|
||||
{
|
||||
short bot_choose_action(struct Node **map, short *board, char *state){
|
||||
|
||||
// get available actions for choosing
|
||||
short available_actions[9];
|
||||
short available_actions_length;
|
||||
get_available_actions(board, available_actions, &available_actions_length);
|
||||
// get available actions for choosing
|
||||
short available_actions[ACTION_NUM];
|
||||
short available_actions_length;
|
||||
get_available_actions(board, available_actions, &available_actions_length);
|
||||
|
||||
// use argmax() to find the best choise,
|
||||
// first we should build an available_actions_state array for saving the state for all available choise.
|
||||
float available_actions_state[9];
|
||||
short available_actions_state_index[9];
|
||||
short available_actions_state_length, index = 0;
|
||||
short temp_index, best_choice;
|
||||
bool zeros = true;
|
||||
for (short i = 0; i < available_actions_length; i++) {
|
||||
// use argmax() to find the best choise,
|
||||
// first we should build an available_actions_state array for saving the state for all available choise.
|
||||
float available_actions_state[ACTION_NUM];
|
||||
short available_actions_state_index[ACTION_NUM];
|
||||
short available_actions_state_length, index = 0;
|
||||
short temp_index, best_choice;
|
||||
bool zeros = true;
|
||||
bool find;
|
||||
float state_weights[ACTION_NUM];
|
||||
|
||||
// find weights in the hash table
|
||||
search(map, state, &find, state_weights);
|
||||
if (!find) {
|
||||
for (short i=0; i<ACTION_NUM; i++){
|
||||
state_weights[i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// get the best choice
|
||||
for (short i=0; i<available_actions_length; i++){
|
||||
temp_index = available_actions[i];
|
||||
available_actions_state[index] = *(table + state * ACTION_NUM + temp_index);
|
||||
if (available_actions_state[index] != 0.0) {
|
||||
|
||||
available_actions_state[index] = state_weights[temp_index];
|
||||
if (available_actions_state[index] != 0.0){
|
||||
zeros = false;
|
||||
}
|
||||
available_actions_state_index[index] = temp_index;
|
||||
@ -84,18 +97,15 @@ short bot_choose_action(float* table, short* board, int state)
|
||||
Opponent random choose a action to do.
|
||||
|
||||
Args:
|
||||
- short *table (array's address): state table for Q-Learning
|
||||
- short *board (array's address): chessboards' status
|
||||
- int state (integer, state hash): hash for board's status
|
||||
- short *board (array's address): chessboards' status
|
||||
|
||||
Results:
|
||||
- short choice (integer): random, -1 means no available action to choose
|
||||
*/
|
||||
short opponent_random_action(float* table, short* board, int state)
|
||||
{
|
||||
short opponent_random_action(short *board){
|
||||
|
||||
// get available actions for choosing
|
||||
short available_actions[9];
|
||||
short available_actions[ACTION_NUM];
|
||||
short available_action_length;
|
||||
get_available_actions(board, available_actions, &available_action_length);
|
||||
|
||||
@ -111,23 +121,24 @@ short opponent_random_action(float* table, short* board, int state)
|
||||
return choice;
|
||||
}
|
||||
|
||||
/*
|
||||
Inilialize the Q-Table
|
||||
// Use Hash Table, so we needn't initilize Q-Table
|
||||
//
|
||||
// /*
|
||||
// Inilialize the Q-Table
|
||||
|
||||
Args:
|
||||
- float *table (two-dim array's start address)
|
||||
// Args:
|
||||
// - float *table (two-dim array's start address)
|
||||
|
||||
Results:
|
||||
- None.
|
||||
*/
|
||||
void init_table(float* table)
|
||||
{
|
||||
for (int i = 0; i < STATE_NUM; i++) {
|
||||
for (int j = 0; j < ACTION_NUM; j++) {
|
||||
*(table + i * ACTION_NUM + j) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Results:
|
||||
// - None.
|
||||
// */
|
||||
// void init_table(float *table){
|
||||
// for (int i=0; i<STATE_NUM; i++){
|
||||
// for (int j=0; j<ACTION_NUM; j++){
|
||||
// *(table + i * ACTION_NUM + j) = 0;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/*
|
||||
Give the chessboard & state, it will return the max reward with the best choice
|
||||
@ -140,15 +151,24 @@ void init_table(float* table)
|
||||
Results:
|
||||
- int max_reward
|
||||
*/
|
||||
float get_estimate_reward(float* table, short* board, int state)
|
||||
{
|
||||
short available_actions[9];
|
||||
float get_estimate_reward(struct Node **map, short *board, char *state){
|
||||
short available_actions[ACTION_NUM];
|
||||
short available_action_length;
|
||||
get_available_actions(board, available_actions, &available_action_length);
|
||||
|
||||
float available_actions_state[9];
|
||||
for (short i = 0; i < available_action_length; i++) {
|
||||
available_actions_state[i] = *(table + state * ACTION_NUM + available_actions[i]); // table[state][available_actions[i]]
|
||||
// find weights in the hash table
|
||||
float state_weights[ACTION_NUM];
|
||||
bool find;
|
||||
search(map, state, &find, state_weights);
|
||||
if (!find) {
|
||||
for (short i=0; i<ACTION_NUM; i++){
|
||||
state_weights[i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
float available_actions_state[ACTION_NUM];
|
||||
for (short i=0; i<available_action_length; i++){
|
||||
available_actions_state[i] = state_weights[available_actions[i]]; // table[state][available_actions[i]]
|
||||
}
|
||||
|
||||
short ans_index;
|
||||
@ -169,41 +189,47 @@ float get_estimate_reward(float* table, short* board, int state)
|
||||
Results:
|
||||
- None
|
||||
*/
|
||||
void run(float* table, short* board, bool train, int times, bool plot)
|
||||
{
|
||||
short available_actions[9];
|
||||
short available_actions_length;
|
||||
short winner;
|
||||
void run(struct Node **map, short *board, bool train, int times, bool plot){
|
||||
short available_actions[ACTION_NUM];
|
||||
short available_actions_length;
|
||||
short winner;
|
||||
short choice, opponent_choice;
|
||||
int state, _state;
|
||||
char state[BIGNUM_LEN], _state[BIGNUM_LEN];
|
||||
float estimate_r, estimate_r_, real_r, r, opponent_r;
|
||||
struct action a;
|
||||
|
||||
float state_weights[ACTION_NUM];
|
||||
bool find;
|
||||
int win = 0;
|
||||
|
||||
for (int episode = 0; episode < times; episode++) {
|
||||
reset(board);
|
||||
state = state_hash(board);
|
||||
while (1) {
|
||||
state_hash(board, state);
|
||||
while (1){
|
||||
// bot choose the action
|
||||
choice = bot_choose_action(table, board, state);
|
||||
choice = bot_choose_action(map, board, state);
|
||||
a.loc = choice;
|
||||
a.player = BOT_SYMBOL;
|
||||
|
||||
estimate_r = *(table + state * ACTION_NUM + choice);
|
||||
act(board, &a, &_state, &r, &opponent_r, &winner);
|
||||
if (plot)
|
||||
show(board);
|
||||
search(map, state, &find, state_weights);
|
||||
if (!find) {
|
||||
for (short i=0; i<ACTION_NUM; i++){
|
||||
state_weights[i] = 0.0;
|
||||
}
|
||||
if (train)
|
||||
insert(map, state);
|
||||
}
|
||||
estimate_r = state_weights[choice];
|
||||
act(board, &a, _state, &r, &opponent_r, &winner);
|
||||
if (plot) show(board);
|
||||
|
||||
// opponent random
|
||||
if (winner == 0) {
|
||||
opponent_choice = opponent_random_action(table, board, state_hash(board));
|
||||
if (opponent_choice != -1) {
|
||||
// // opponent random
|
||||
if (winner == 0){
|
||||
opponent_choice = opponent_random_action(board);
|
||||
if (opponent_choice != -1){
|
||||
a.loc = opponent_choice;
|
||||
a.player = OPPONENT_SYMBOL;
|
||||
act(board, &a, &_state, &opponent_r, &r, &winner);
|
||||
if (plot)
|
||||
show(board);
|
||||
act(board, &a, _state, &opponent_r, &r, &winner);
|
||||
if (plot) show(board);
|
||||
}
|
||||
}
|
||||
get_available_actions(board, available_actions, &available_actions_length);
|
||||
@ -215,18 +241,20 @@ void run(float* table, short* board, bool train, int times, bool plot)
|
||||
}
|
||||
real_r = r;
|
||||
} else {
|
||||
estimate_r_ = get_estimate_reward(table, board, _state);
|
||||
estimate_r_ = get_estimate_reward(map, board, _state);
|
||||
real_r = r + LAMBDA * estimate_r_;
|
||||
}
|
||||
if (train) {
|
||||
// printf("update");
|
||||
*(table + state * ACTION_NUM + choice) += (LR * (real_r - estimate_r)); // table[state][choice] += LR * (real_r - estimate_r)
|
||||
if (train){
|
||||
state_weights[choice] += (LR * (real_r - estimate_r));
|
||||
update(map, state, choice, state_weights[choice]);
|
||||
}
|
||||
for (int i=0; i<BIGNUM_LEN; i++){
|
||||
state[i] = _state[i];
|
||||
}
|
||||
state = _state;
|
||||
|
||||
if ((winner != 0) || (available_actions_length == 0)) {
|
||||
// printf("break\n");
|
||||
if (winner == 1) {
|
||||
|
||||
if ((winner != 0) || (available_actions_length == 0)){
|
||||
if (winner == 1){
|
||||
win += 1;
|
||||
}
|
||||
break;
|
||||
@ -235,5 +263,6 @@ void run(float* table, short* board, bool train, int times, bool plot)
|
||||
}
|
||||
|
||||
if (!train)
|
||||
printf("%d/%d, %f\%\n", win, 10000, (float)win / 10000);
|
||||
// printf("%d/%d, %f\%\n", win, 10000, (float)win/10000);
|
||||
printf("%f\n", (float)win/times);
|
||||
}
|
||||
|
||||
13
q-learning.h
13
q-learning.h
@ -1,6 +1,7 @@
|
||||
short float_argmax(float* arr, short length);
|
||||
short bot_choose_action(float* table, short* board, int state);
|
||||
short opponent_random_action(float* table, short* board, int state);
|
||||
void init_table(float* table);
|
||||
float get_estimate_reward(float* table, short* board, int state);
|
||||
void run(float* table, short* board, bool train, int times, bool plot);
|
||||
#include "hash-table.h"
|
||||
|
||||
short float_argmax(float *arr, short length);
|
||||
short bot_choose_action(struct Node **map, short *board, char *state);
|
||||
short opponent_random_action(short *board);
|
||||
float get_estimate_reward(struct Node **map, short *board, char *state);
|
||||
void run(struct Node **map, short *board, bool train, int times, bool plot);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user