Compare commits

..

10 Commits

Author SHA1 Message Date
d1279f5c6d
docs: update README 2023-06-02 23:57:46 +08:00
c3a0335ff1
Merge: replace 'tic-tac-toe' with '4 in a row'
'tic-tac-toe' version is located on the other branch.
2023-06-02 23:48:55 +08:00
cc369c7094
fix: update Makefile 2023-06-02 23:36:06 +08:00
0fb79b3e1e
feat: change q-learning method to fit 'on 4 in a row' 2023-06-02 23:34:37 +08:00
b024eec8e4
feat: calculate state hash 2023-06-02 20:19:46 +08:00
7fcadce548
feat: Big num for state representation 2023-06-02 16:52:59 +08:00
821bc5727f
feat: set up 'four in a row' enviroment 2023-06-02 15:47:02 +08:00
Ting-Jun Wang
5cf2ef7936
Merge pull request #1 from eeeXun/clang-format
GitHub-Action auto format with clang-format
2023-05-31 16:45:50 +08:00
eeeXun
7a68a06c86 style(format): run clang-format 2023-05-31 11:31:15 +08:00
eeeXun
7ba9db7f83 chore: github-action auto format with clang-format 2023-05-31 11:21:44 +08:00
14 changed files with 684 additions and 262 deletions

225
.clang-format Normal file
View File

@ -0,0 +1,225 @@
---
Language: Cpp
# BasedOnStyle: WebKit
AccessModifierOffset: -4
AlignAfterOpenBracket: DontAlign
AlignArrayOfStructures: None
AlignConsecutiveAssignments:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: true
AlignConsecutiveBitFields:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignConsecutiveDeclarations:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignConsecutiveMacros:
Enabled: false
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: false
PadOperators: false
AlignEscapedNewlines: Right
AlignOperands: DontAlign
AlignTrailingComments:
Kind: Never
OverEmptyLines: 0
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseLabelsOnASingleLine: false
AllowShortEnumsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: MultiLine
AttributeMacros:
- __capability
BinPackArguments: true
BinPackParameters: true
BitFieldColonSpacing: Both
BraceWrapping:
AfterCaseLabel: false
AfterClass: false
AfterControlStatement: Never
AfterEnum: false
AfterExternBlock: false
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: true
SplitEmptyRecord: true
SplitEmptyNamespace: true
BreakAfterAttributes: Never
BreakAfterJavaFieldAnnotations: false
BreakArrays: true
BreakBeforeBinaryOperators: All
BreakBeforeConceptDeclarations: Always
BreakBeforeBraces: WebKit
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeComma
BreakInheritanceList: BeforeColon
BreakStringLiterals: true
ColumnLimit: 0
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: LogicalBlock
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: false
ForEachMacros:
- foreach
- Q_FOREACH
- BOOST_FOREACH
IfMacros:
- KJ_IF_MAYBE
IncludeBlocks: Preserve
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
SortPriority: 0
CaseSensitive: false
- Regex: '^(<|"(gtest|gmock|isl|json)/)'
Priority: 3
SortPriority: 0
CaseSensitive: false
- Regex: '.*'
Priority: 1
SortPriority: 0
CaseSensitive: false
IncludeIsMainRegex: '(Test)?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: false
IndentCaseLabels: false
IndentExternBlock: AfterExternBlock
IndentGotoLabels: true
IndentPPDirectives: None
IndentRequiresClause: true
IndentWidth: 4
IndentWrappedFunctionNames: false
InsertBraces: false
InsertNewlineAtEOF: false
InsertTrailingCommas: None
IntegerLiteralSeparator:
Binary: 0
BinaryMinDigits: 0
Decimal: 0
DecimalMinDigits: 0
Hex: 0
HexMinDigits: 0
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: true
LambdaBodyIndentation: Signature
LineEnding: DeriveLF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: Inner
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 4
ObjCBreakBeforeNestedBlockParam: true
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PackConstructorInitializers: BinPack
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakOpenParenthesis: 0
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyIndentedWhitespace: 0
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
PPIndentWidth: -1
QualifierAlignment: Leave
ReferenceAlignment: Pointer
ReflowComments: true
RemoveBracesLLVM: false
RemoveSemicolon: false
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
SeparateDefinitionBlocks: Leave
ShortNamespaceLines: 1
SortIncludes: CaseSensitive
SortJavaStaticImport: Before
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceAroundPointerQualifiers: Default
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeParensOptions:
AfterControlStatements: true
AfterForeachMacros: true
AfterFunctionDefinitionName: false
AfterFunctionDeclarationName: false
AfterIfMacros: true
AfterOverloadedOperator: false
AfterRequiresInClause: false
AfterRequiresInExpression: false
BeforeNonEmptyParentheses: false
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: Never
SpacesInConditionalStatement: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Latest
StatementAttributeLikeMacros:
- Q_EMIT
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
TabWidth: 8
UseTab: Never
WhitespaceSensitiveMacros:
- BOOST_PP_STRINGIZE
- CF_SWIFT_NAME
- NS_SWIFT_NAME
- PP_STRINGIZE
- STRINGIZE
...

26
.github/workflows/format.yml vendored Normal file
View File

@ -0,0 +1,26 @@
# Workflow: run clang-format over the C sources and commit the result.
name: format
# Only runs for pushes to master that touch a .c or .h file.
on:
push:
branches:
- master
paths:
- "**.c"
- "**.h"
jobs:
format:
runs-on: ubuntu-latest
# NOTE(review): write access is presumably required so the commit step below can push - confirm.
permissions:
contents: write
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
# clang-format is installed through pip; no version is pinned here.
- name: Install clang-format
run: pip install clang-format
# The globs only match files in the working directory (not sub-directories); -i rewrites them in place.
- name: Format with clang-format
run: clang-format -i *.c *.h
- name: Commit changes
uses: stefanzweifel/git-auto-commit-action@v4
with:
commit_message: "style(format): run clang-format"

View File

@ -1,5 +1,5 @@
all: a.out all: a.out
a.out: main.c enviroment.c enviroment.h q-learning.c q-learning.h constant.h a.out: main.c enviroment.c enviroment.h q-learning.c q-learning.h bignum.c bignum.h hash-table.c hash-table.h constant.h
gcc main.c enviroment.c q-learning.c -lm gcc main.c enviroment.c q-learning.c bignum.c constant.h hash-table.c -lm
run: run:
./a.out ./a.out

View File

@ -1,10 +1,10 @@
# Q-Learning-with-Tic-Tac-Toe # Q-Learning-with-Four-in-a-Row
Project for 1112 NCNU CSIE "Parallel Programming with the Message-Passing Interface" Project for 1112 NCNU CSIE "Parallel Programming with the Message-Passing Interface"
## Setup ## Setup
``` ```
git clone https://github.com/snsd0805/Q-learning-with-Tic-Tac-Toe.git git clone https://github.com/snsd0805/Q-learning-with-Four-in-a-Row.git
``` ```
### Compile ### Compile

43
bignum.c Normal file
View File

@ -0,0 +1,43 @@
#include <stdio.h>
#include <stdbool.h>
#include "bignum.h"
#include "constant.h"
/*
    Convert an integer into a fixed-width decimal BigNum.

    Args:
        - long long num (integer): value to convert, expected to be >= 0
          (a negative value would produce non-digit characters)

    Results:
        - struct BigNum: BIGNUM_LEN ASCII digits, most significant digit first,
          padded with '0' on the left and terminated with '\0'.
          Digits that do not fit into BIGNUM_LEN places are dropped.
*/
struct BigNum long_to_BigNum(long long num) {
    struct BigNum ans;

    // fill from the least significant digit (right-most position) to the left
    for (int i = BIGNUM_LEN - 1; i >= 0; i--) {
        ans.num[i] = (char)('0' + (num % 10));
        num /= 10;
    }
    // num[] has BIGNUM_LEN+1 bytes; terminate it so it can be used as a C string
    ans.num[BIGNUM_LEN] = '\0';
    return ans;
}
/*
    Add two BigNums digit by digit.

    Args:
        - struct BigNum a: first operand, BIGNUM_LEN ASCII digits
        - struct BigNum b: second operand, BIGNUM_LEN ASCII digits

    Results:
        - struct BigNum: a + b as BIGNUM_LEN ASCII digits terminated with '\0'.
          A carry out of the most significant digit is discarded.
*/
struct BigNum add(struct BigNum a, struct BigNum b) {
    struct BigNum ans;
    int carry = 0;

    // walk from the least significant digit to the most significant one
    for (int i = BIGNUM_LEN - 1; i >= 0; i--) {
        int s = (a.num[i] - '0') + (b.num[i] - '0') + carry;
        carry = s / 10;
        ans.num[i] = (char)('0' + (s % 10));
    }
    // num[] has BIGNUM_LEN+1 bytes; terminate it so it can be used as a C string
    ans.num[BIGNUM_LEN] = '\0';
    return ans;
}
/*
    Multiply a BigNum by a small integer.

    Args:
        - struct BigNum a: multiplicand, BIGNUM_LEN ASCII digits
        - int b (integer): multiplier, expected to be >= 0

    Results:
        - struct BigNum: a * b as BIGNUM_LEN ASCII digits terminated with '\0'.
          A carry out of the most significant digit is discarded.
*/
struct BigNum mul(struct BigNum a, int b) {
    struct BigNum ans;
    // wide accumulators: digit * b + carry does not fit in a short for larger b
    long long carry = 0;

    // walk from the least significant digit to the most significant one
    for (int i = BIGNUM_LEN - 1; i >= 0; i--) {
        long long s = (long long)(a.num[i] - '0') * b + carry;
        carry = s / 10;
        ans.num[i] = (char)('0' + (s % 10));
    }
    // num[] has BIGNUM_LEN+1 bytes; terminate it so it can be used as a C string
    ans.num[BIGNUM_LEN] = '\0';
    return ans;
}

8
bignum.h Normal file
View File

@ -0,0 +1,8 @@
#ifndef BIGNUM_H
#define BIGNUM_H

#include "constant.h"

/*
    Fixed-width decimal number stored as ASCII digits, most significant digit first.
    The array holds BIGNUM_LEN digits plus one extra byte.
*/
struct BigNum {
    char num[BIGNUM_LEN+1];
};

/* Convert an integer into BIGNUM_LEN zero-padded decimal digits. */
struct BigNum long_to_BigNum(long long num);
/* Digit-wise a + b; a carry out of the top digit is discarded. */
struct BigNum add(struct BigNum a, struct BigNum b);
/* Digit-wise a * b; a carry out of the top digit is discarded. */
struct BigNum mul(struct BigNum a, int b);

#endif /* BIGNUM_H */

View File

@ -6,6 +6,12 @@
#define LAMBDA 0.9 // discount factor #define LAMBDA 0.9 // discount factor
#define STATE_NUM 19683 #define STATE_NUM 19683
#define ACTION_NUM 9 #define ACTION_NUM 7
#define EPISODE_NUM 100000 #define EPISODE_NUM 1000000
#define FIRST true #define FIRST true
#define ROW_NUM 6
#define COL_NUM 7
#define BIGNUM_LEN 22
#define TABLE_SIZE 1000000000

View File

@ -3,24 +3,31 @@
#include <assert.h> #include <assert.h>
#include "constant.h" #include "constant.h"
#include "enviroment.h" #include "enviroment.h"
#include "bignum.h"
short PATHS[8][3] = { struct BigNum POWs[42] = {
{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, "0000000000000000000001", "0000000000000000000003", "0000000000000000000009", "0000000000000000000027", "0000000000000000000081",
{0, 3, 6}, {1, 4, 7}, {2, 5, 8}, "0000000000000000000243", "0000000000000000000729", "0000000000000000002187", "0000000000000000006561", "0000000000000000019683",
{0, 4, 8}, {2, 4, 6} "0000000000000000059049", "0000000000000000177147", "0000000000000000531441", "0000000000000001594323", "0000000000000004782969",
"0000000000000014348907", "0000000000000043046721", "0000000000000129140163", "0000000000000387420489", "0000000000001162261467",
"0000000000003486784401", "0000000000010460353203", "0000000000031381059609", "0000000000094143178827", "0000000000282429536481",
"0000000000847288609443", "0000000002541865828329", "0000000007625597484987", "0000000022876792454961", "0000000068630377364883",
"0000000205891132094649", "0000000617673396283947", "0000001853020188851841", "0000005559060566555523", "0000016677181699666569",
"0000050031545098999707", "0000150094635296999121", "0000450283905890997363", "0001350851717672992089", "0004052555153018976267",
"0012157665459056928801", "0036472996377170786403"
}; };
/* /*
Reset the game, clear the chessboard. Reset the game, clear the chessboard.
Args: Args:
- short *board (array's address): chessboard's status - short *board (array's start address): chessboard's status
Results: Results:
- None, set all blocks on the chessboard to zero. - None, set all blocks on the chessboard to zero.
*/ */
void reset(short* board){ void reset(short* board){
for (short i=0; i<9; i++) for (short i=0; i<(ROW_NUM*COL_NUM); i++)
board[i] = 0; board[i] = 0;
} }
@ -35,20 +42,21 @@ void reset(short* board){
*/ */
void show(short *board){ void show(short *board){
short loc; short loc;
printf("┼───┼───┼───┼\n"); for (short i=0; i<COL_NUM; i++){
for (short i=0; i<3; i++){ printf("%d ", i);
printf("");
for (short j=0; j<3; j++){
loc = 3*i+j;
if (board[loc] == 0)
printf("");
else if (board[loc] == BOT_SYMBOL)
printf("○ │ ");
else
printf("✕ │ ");
} }
printf("\n"); printf("\n");
printf("┼───┼───┼───┼\n"); for (short i=(ROW_NUM*COL_NUM-1); i>=0; i--){
if (board[i] == BOT_SYMBOL) {
printf("");
} else if(board[i] == OPPONENT_SYMBOL) {
printf("");
} else {
printf("");
}
if (i%COL_NUM == 0){
printf("\n");
}
} }
printf("\n\n"); printf("\n\n");
} }
@ -64,14 +72,36 @@ void show(short *board){
Results: Results:
- None. All available actions are saved into "result" and the number of actions is saved in "length" - None. All available actions are saved into "result" and the number of actions is saved in "length"
*/ */
void get_available_actions(short *board, short *result, short *length){ void get_available_actions(short *board, short *result, short *length){
short index = 0; short index = 0;
for (int i=0; i<9; i++) for (int i=0; i<COL_NUM; i++)
if (board[i] == 0) if (board[(ROW_NUM*COL_NUM-1)-i] == 0)
result[index++] = i; result[index++] = i;
*length = index; *length = index;
} }
/*
    Read one cell of the chessboard, with bounds checking.

    Args:
        - short *board (array's start pointer): chessboard's status
        - short row (integer): row number of the cell
        - short col (integer): column number of the cell

    Results:
        - short value (integer): the value stored at chessboard[row][col],
          or -1 when row or col lies outside the board
*/
short get_loc_status(short *board, short row, short col) {
    int inside = (0 <= row) && (row < ROW_NUM) && (0 <= col) && (col < COL_NUM);

    return inside ? board[row * COL_NUM + col] : -1;
}
/* /*
Return winner's number; Return winner's number;
@ -80,13 +110,57 @@ void get_available_actions(short *board, short *result, short *length){
Results: Results:
- short winner_number(integer): winner's number, 0 for no winner now, 1 for Bot, 2 for opponent - short winner_number(integer): winner's number, 0 for no winner now, 1 for Bot, 2 for opponent
board's coordinate diagram
^
| 5
| 4
| 3
| 2
| 1
| 0
<-----------------------------
6 5 4 3 2 1 0 |
*/ */
short get_winner(short *board){ short get_winner(short *board){
int a, b, c; short a, b, c, d;
for (int i=0; i<8; i++){ for (short i=0; i<ROW_NUM; i++){
a = PATHS[i][0]; b = PATHS[i][1]; c = PATHS[i][2]; for (short j=0; j<COL_NUM; j++){
if ((board[a] == board[b]) && (board[b] == board[c]) && (board[a] != 0)){ // horizontal
return board[a]; a = get_loc_status(board, i, j);
b = get_loc_status(board, i, j+1);
c = get_loc_status(board, i, j+2);
d = get_loc_status(board, i, j+3);
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
return a;
}
// vertical
a = get_loc_status(board, i, j);
b = get_loc_status(board, i+1, j);
c = get_loc_status(board, i+2, j);
d = get_loc_status(board, i+3, j);
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
return a;
}
// slash (/)
a = get_loc_status(board, i, j);
b = get_loc_status(board, i+1, j-1);
c = get_loc_status(board, i+2, j-2);
d = get_loc_status(board, i+3, j-3);
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
return a;
}
// backslash (\)
a = get_loc_status(board, i, j);
b = get_loc_status(board, i+1, j+1);
c = get_loc_status(board, i+2, j+2);
d = get_loc_status(board, i+3, j+3);
if ((a == b) && (b == c) && (c == d) && (a!=0)) {
return a;
}
} }
} }
return 0; return 0;
@ -97,19 +171,54 @@ short get_winner(short *board){
Args: Args:
- short *board (array's address): chessboard's status - short *board (array's address): chessboard's status
- char *hash (a string): size is BIGNUM_LEN, the hash will be written here
Results: Results:
- int hash (integer): chessboard's status in i-th block * pow(3, i) - None.
*/ */
int state_hash(short *board){ void state_hash(short *board, char *hash){
int base, hash = 0; struct BigNum sum, temp;
for (int i=0; i<9; i++){ for (short i=0; i<BIGNUM_LEN; i++){
base = pow(3, i); sum.num[i] = '0';
hash += (base * board[i]);
}
return hash;
} }
for (short i=0; i<(ROW_NUM*COL_NUM); i++) {
// printf("MUL:\n");
// printf("%s\n", POWs[i].num);
temp = mul(POWs[i], board[i]);
// printf("%s\n\n", temp.num);
// printf("ADD:\n");
// printf("%s\n", sum.num);
// printf("%s\n", temp.num);
sum = add(sum, temp);
// printf("%s\n\n", sum.num);
}
for (int i=0; i<BIGNUM_LEN; i++){
hash[i] = sum.num[i];
}
}
/*
    Drop a piece into a column: it lands on the lowest empty cell of that column.

    Args:
        - short *board (array's start address): chessboard's status (ROW_NUM * COL_NUM cells)
        - struct action *a (struct pointer): a->loc is the chosen action (0 ~ COL_NUM-1),
          a->player is the symbol to place

    Results:
        - None. The chessboard is updated in place.
          Nothing is written when a->loc is out of range or the column is already full.
*/
void fall(short *board, struct action *a) {
    if ((a->loc < 0) || (a->loc >= COL_NUM)) {
        return;
    }

    // action i addresses the top-row cell board[ROW_NUM*COL_NUM-1-i], i.e. column COL_NUM-1-i
    int col = COL_NUM - 1 - a->loc;
    int row = ROW_NUM - 1;

    // top cell occupied: the column is full, there is nowhere to place the piece
    if (board[row * COL_NUM + col] != 0) {
        return;
    }

    // move down while the cell below is empty; check the bound before reading the cell
    while ((row > 0) && (board[(row - 1) * COL_NUM + col] == 0)) {
        row--;
    }
    board[row * COL_NUM + col] = a->player;
}
/* /*
Act on the chessboard. Act on the chessboard.
@ -117,7 +226,7 @@ int state_hash(short *board){
Args: Args:
- short *board (array's address): chessboards' status - short *board (array's address): chessboards' status
- struct action *a (a action's pointer): include player & choose loc - struct action *a (a action's pointer): include player & choose loc
- int *state (pointer): for return. To save the chessboard's state hash which after doing this action - char *state (a string): for return. To save the chessboard's state hash which after doing this action
- float *reward (pointer): for return. To save the number of rewards which the player gets after doing this action. - float *reward (pointer): for return. To save the number of rewards which the player gets after doing this action.
- float *opponent_reward (pointer): for return. To save the number of rewards which the opponents gets after the player doing this action. - float *opponent_reward (pointer): for return. To save the number of rewards which the opponents gets after the player doing this action.
- short *winner (pointer): for return. To save the winner in this action. If haven't finish, it will be zero. - short *winner (pointer): for return. To save the winner in this action. If haven't finish, it will be zero.
@ -125,21 +234,20 @@ int state_hash(short *board){
Results: Results:
- None. Save in state & reward & winner - None. Save in state & reward & winner
*/ */
void act(short *board, struct action *a, int *state, float *reward, float *opponent_reward, short *winner){ void act(short *board, struct action *a, char *state, float *reward, float *opponent_reward, short *winner){
// printf("Act( player=%d, action=%d )\n", a->player, a->loc); // printf("Act( player=%d, action=%d )\n", a->player, a->loc);
assert(board[a->loc] == 0); assert(board[(ROW_NUM*COL_NUM-1)-(a->loc)] == 0);
board[a->loc] = a->player;
fall(board, a);
*winner = get_winner(board); *winner = get_winner(board);
*state = state_hash(board); state_hash(board, state);
if (*winner == a->player){ if (*winner == a->player){
*reward = 1.0; *reward = 1.0;
*opponent_reward = -1.0; *opponent_reward = -1.0;
} } else if (*winner != 0) {
else if(*winner != 0){
*reward = -1.0; *reward = -1.0;
*opponent_reward = 1.0; *opponent_reward = 1.0;
} } else {
else{
*reward = 0; *reward = 0;
*opponent_reward = 0; *opponent_reward = 0;
} }

View File

@ -7,5 +7,5 @@ void reset(short* board);
void show(short *board); void show(short *board);
void get_available_actions(short *board, short *result, short *length); void get_available_actions(short *board, short *result, short *length);
short get_winner(short *board); short get_winner(short *board);
int state_hash(short *board); void state_hash(short *board, char *hash);
void act(short *board, struct action *a, int *state, float *reward, float *opponent_reward, short *winner); void act(short *board, struct action *a, char *state, float *reward, float *opponent_reward, short *winner);

View File

@ -3,14 +3,8 @@
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>
#include <time.h> #include <time.h>
#include "hash-table.h"
#define TABLE_SIZE 10 #include "constant.h"
struct Node {
char key[48];
int value;
struct Node *next;
};
long long hash_function(char *key) { long long hash_function(char *key) {
long long hash = 0; long long hash = 0;
@ -20,115 +14,69 @@ long long hash_function(char *key) {
return hash ; return hash ;
} }
void insert(struct Node **table, char *key, int value) { void insert(struct Node **table, char *key) {
long long hash = hash_function(key); long long hash = hash_function(key);
printf("Hash: %lli\n", hash);
struct Node *node = malloc(sizeof(struct Node)); struct Node *node = malloc(sizeof(struct Node));
struct Node *temp, *past; struct Node *temp, *past;
strcpy(node->key, key); strcpy(node->key, key);
node->value = value; // init
for (short i=0; i<ACTION_NUM; i++){
node->value[i] = 0.0;
}
node->next = NULL; node->next = NULL;
if (table[hash] == NULL){ if (table[hash] == NULL){
table[hash] = node; table[hash] = node;
printf("Create.\n");
} else { } else {
printf("Add.\n");
temp = table[hash]; temp = table[hash];
past = NULL; past = NULL;
while(temp != NULL){ while(temp != NULL){
assert(temp->key != key); assert(strcmp(temp->key, key)!=0);
printf("%s -> ", temp->key);
past = temp; past = temp;
temp = temp->next; temp = temp->next;
} }
printf("\n");
past->next = node; past->next = node;
} }
} }
void long_to_str(long long num, char *s, int length) { void search(struct Node **table, char *key, bool *find, float *ans) {
int temp;
for (int i=length-1; i>=0; i--) {
temp = num % 10;
num /= 10;
s[i] = (char)(temp + 48);
}
}
int search(struct Node **table, char *key) {
long long hash = hash_function(key); long long hash = hash_function(key);
struct Node *temp, *past; struct Node *temp, *past;
*find = false;
if (table[hash] == NULL){ if (table[hash] != NULL){
return -1;
} else {
temp = table[hash]; temp = table[hash];
past = NULL; past = NULL;
while(temp != NULL){ while(temp != NULL){
// printf("%s - %s\n", temp->key, key);
if (strcmp(temp->key, key) == 0){ if (strcmp(temp->key, key) == 0){
return temp->value; *find = true;
for (short i=0; i<ACTION_NUM; i++){
ans[i] = temp->value[i];
} }
past = temp;
temp = temp->next;
}
return -1;
}
}
void update(struct Node **table, char *key, int value) {
long long hash = hash_function(key);
struct Node *temp, *past;
temp = table[hash];
past = NULL;
while(temp != NULL){
if (strcmp(temp->key, key) == 0){
temp->value = value;
break; break;
} }
past = temp; past = temp;
temp = temp->next; temp = temp->next;
} }
} }
int main(){
struct Node ** table; // pointer to pointer
int size;
srand(time(NULL));
table = malloc(TABLE_SIZE * sizeof(struct Node*));
for (int i=0; i<TABLE_SIZE; i++){
table[i] = NULL;
} }
long long a = 1234567890; void update(struct Node **table, char *key, short action, float value) {
char s[21]; long long hash = hash_function(key);
for (int i=0; i<50; i++){ struct Node *temp, *past;
a = (long long)rand(); assert(table[hash]!=NULL);
printf("%lli\n", a);
long_to_str(a, s, 20);
printf("%s\n", s);
insert(table, s, i);
printf("\n");
}
int ans; temp = table[hash];
while (1) { past = NULL;
printf("> "); while(temp != NULL){
scanf("%lli", &a); if (strcmp(temp->key, key) == 0){
printf("HERE\n"); temp->value[action] = value;
long_to_str(a, s, 20); break;
printf("HERE\n"); }
past = temp;
update(table, s, 100); temp = temp->next;
ans = search(table, s);
printf("%d\n\n", ans);
} }
// long long a = hash_function("9999999999999");
// printf("%lli\n", a);
} }

13
hash-table.h Normal file
View File

@ -0,0 +1,13 @@
#ifndef HASH_TABLE_H
#define HASH_TABLE_H

#include "constant.h"
#include <stdbool.h>

/*
    One entry of the hash table; entries that share a bucket are chained through "next".
*/
struct Node {
    char key[BIGNUM_LEN+1];     // state hash string used as the lookup key
    float value[ACTION_NUM];    // one weight per action for this state
    struct Node *next;          // next entry in the same bucket, NULL at the end of the chain
};

/* Map a key string to a bucket index. */
long long hash_function(char *key);
/* Add a new entry for "key" with every action weight set to 0. */
void insert(struct Node **table, char *key);
/* Look up "key": *find tells whether it exists; if so its weights are copied into ans. */
void search(struct Node **table, char *key, bool *find, float *ans);
/* Set the weight of "action" for an existing "key" to "value". */
void update(struct Node **table, char *key, short action, float value);

#endif /* HASH_TABLE_H */

28
main.c
View File

@ -1,19 +1,27 @@
#include <stdio.h>
#include <time.h>
#include <stdlib.h>
#include <stdbool.h>
#include "constant.h" #include "constant.h"
#include "enviroment.h" #include "enviroment.h"
#include "q-learning.h" #include "q-learning.h"
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
int main(){ int main(){
short board[9]= {0}; // tic tac toe's chessboard short board[ROW_NUM][COL_NUM]= {0};
float table[STATE_NUM][ACTION_NUM]; // q-learning table short winner;
struct Node ** map; // pointer to pointer, hash table
bool find;
float state[ACTION_NUM];
srand(time(NULL)); srand(time(NULL));
init_table(&table[0][0]);
run(&table[0][0], board, false, 10000, false); // init hash table
run(&table[0][0], board, true, EPISODE_NUM, false); map = malloc(TABLE_SIZE * sizeof(struct Node*));
run(&table[0][0], board, false, 10000, false); for (int i=0; i<TABLE_SIZE; i++){
map[i] = NULL;
}
run(map, &board[0][0], false, 10000, false);
run(map, &board[0][0], true, EPISODE_NUM, false);
run(map, &board[0][0], false, 10000, true);
} }

View File

@ -1,11 +1,12 @@
#include <stdio.h>
#include <float.h> #include <float.h>
#include <stdbool.h>
#include <limits.h> #include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include "constant.h" #include "constant.h"
#include "enviroment.h" #include "enviroment.h"
#include "hash-table.h"
/* /*
Return the index with the max value in the array Return the index with the max value in the array
@ -17,7 +18,8 @@
Results: Results:
- short index (integer): the index with the max value - short index (integer): the index with the max value
*/ */
short float_argmax(float *arr, short length){ short float_argmax(float* arr, short length)
{
float ans = -1, max = -FLT_MAX; float ans = -1, max = -FLT_MAX;
for (short i = 0; i < length; i++) { for (short i = 0; i < length; i++) {
if (arr[i] > max) { if (arr[i] > max) {
@ -28,7 +30,6 @@ short float_argmax(float *arr, short length){
return ans; return ans;
} }
/* /*
Choose the next action with Epsilon-Greedy. Choose the next action with Epsilon-Greedy.
EPSILON means the probability to choose the best action in this state from Q-Table. EPSILON means the probability to choose the best action in this state from Q-Table.
@ -37,28 +38,41 @@ short float_argmax(float *arr, short length){
Args: Args:
- short *table (array's address): state table for Q-Learning - short *table (array's address): state table for Q-Learning
- short *board (array's address): chessboards' status - short *board (array's address): chessboards' status
- int state (integer, state hash): hash for board's status - char *state (string, state hash): hash for board's status
Results: Results:
- short best_choice - short best_choice
*/ */
short bot_choose_action(float *table, short *board, int state){ short bot_choose_action(struct Node **map, short *board, char *state){
// get available actions for choosing // get available actions for choosing
short available_actions[9]; short available_actions[ACTION_NUM];
short available_actions_length; short available_actions_length;
get_available_actions(board, available_actions, &available_actions_length); get_available_actions(board, available_actions, &available_actions_length);
// use argmax() to find the best choise, // use argmax() to find the best choise,
// first we should build an available_actions_state array for saving the state for all available choise. // first we should build an available_actions_state array for saving the state for all available choise.
float available_actions_state[9]; float available_actions_state[ACTION_NUM];
short available_actions_state_index[9]; short available_actions_state_index[ACTION_NUM];
short available_actions_state_length, index = 0; short available_actions_state_length, index = 0;
short temp_index, best_choice; short temp_index, best_choice;
bool zeros = true; bool zeros = true;
bool find;
float state_weights[ACTION_NUM];
// find weights in the hash table
search(map, state, &find, state_weights);
if (!find) {
for (short i=0; i<ACTION_NUM; i++){
state_weights[i] = 0.0;
}
}
// get the best choice
for (short i=0; i<available_actions_length; i++){ for (short i=0; i<available_actions_length; i++){
temp_index = available_actions[i]; temp_index = available_actions[i];
available_actions_state[index] = *(table + state * ACTION_NUM + temp_index);
available_actions_state[index] = state_weights[temp_index];
if (available_actions_state[index] != 0.0){ if (available_actions_state[index] != 0.0){
zeros = false; zeros = false;
} }
@ -83,17 +97,15 @@ short bot_choose_action(float *table, short *board, int state){
Opponent random choose a action to do. Opponent random choose a action to do.
Args: Args:
- short *table (array's address): state table for Q-Learning
- short *board (array's address): chessboards' status - short *board (array's address): chessboards' status
- int state (integer, state hash): hash for board's status
Results: Results:
- short choice (integer): random, -1 means no available action to choose - short choice (integer): random, -1 means no available action to choose
*/ */
short opponent_random_action(float *table, short *board, int state){ short opponent_random_action(short *board){
// get available actions for choosing // get available actions for choosing
short available_actions[9]; short available_actions[ACTION_NUM];
short available_action_length; short available_action_length;
get_available_actions(board, available_actions, &available_action_length); get_available_actions(board, available_actions, &available_action_length);
@ -109,22 +121,24 @@ short opponent_random_action(float *table, short *board, int state){
return choice; return choice;
} }
/* // Use Hash Table, so we needn't initilize Q-Table
Inilialize the Q-Table //
// /*
// Inilialize the Q-Table
Args: // Args:
- float *table (two-dim array's start address) // - float *table (two-dim array's start address)
Results: // Results:
- None. // - None.
*/ // */
void init_table(float *table){ // void init_table(float *table){
for (int i=0; i<STATE_NUM; i++){ // for (int i=0; i<STATE_NUM; i++){
for (int j=0; j<ACTION_NUM; j++){ // for (int j=0; j<ACTION_NUM; j++){
*(table + i * ACTION_NUM + j) = 0; // *(table + i * ACTION_NUM + j) = 0;
} // }
} // }
} // }
/* /*
Give the chessboard & state, it will return the max reward with the best choice Give the chessboard & state, it will return the max reward with the best choice
@ -137,14 +151,24 @@ void init_table(float *table){
Results: Results:
- int max_reward - int max_reward
*/ */
float get_estimate_reward(float *table, short *board, int state){ float get_estimate_reward(struct Node **map, short *board, char *state){
short available_actions[9]; short available_actions[ACTION_NUM];
short available_action_length; short available_action_length;
get_available_actions(board, available_actions, &available_action_length); get_available_actions(board, available_actions, &available_action_length);
float available_actions_state[9]; // find weights in the hash table
float state_weights[ACTION_NUM];
bool find;
search(map, state, &find, state_weights);
if (!find) {
for (short i=0; i<ACTION_NUM; i++){
state_weights[i] = 0.0;
}
}
float available_actions_state[ACTION_NUM];
for (short i=0; i<available_action_length; i++){ for (short i=0; i<available_action_length; i++){
available_actions_state[i] = *(table + state * ACTION_NUM + available_actions[i]); // table[state][available_actions[i]] available_actions_state[i] = state_weights[available_actions[i]]; // table[state][available_actions[i]]
} }
short ans_index; short ans_index;
@ -165,37 +189,46 @@ float get_estimate_reward(float *table, short *board, int state){
Results: Results:
- None - None
*/ */
void run(float *table, short *board, bool train, int times, bool plot){ void run(struct Node **map, short *board, bool train, int times, bool plot){
short available_actions[9]; short available_actions[ACTION_NUM];
short available_actions_length; short available_actions_length;
short winner; short winner;
short choice, opponent_choice; short choice, opponent_choice;
int state, _state; char state[BIGNUM_LEN], _state[BIGNUM_LEN];
float estimate_r, estimate_r_, real_r, r, opponent_r; float estimate_r, estimate_r_, real_r, r, opponent_r;
struct action a; struct action a;
float state_weights[ACTION_NUM];
bool find;
int win = 0; int win = 0;
for (int episode = 0; episode < times; episode++) { for (int episode = 0; episode < times; episode++) {
reset(board); reset(board);
state = state_hash(board); state_hash(board, state);
while (1){ while (1){
// bot choose the action // bot choose the action
choice = bot_choose_action(table, board, state); choice = bot_choose_action(map, board, state);
a.loc = choice; a.loc = choice;
a.player = BOT_SYMBOL; a.player = BOT_SYMBOL;
estimate_r = *(table + state * ACTION_NUM + choice); search(map, state, &find, state_weights);
act(board, &a, &_state, &r, &opponent_r, &winner); if (!find) {
for (short i=0; i<ACTION_NUM; i++){
state_weights[i] = 0.0;
}
if (train)
insert(map, state);
}
estimate_r = state_weights[choice];
act(board, &a, _state, &r, &opponent_r, &winner);
if (plot) show(board); if (plot) show(board);
// opponent random // // opponent random
if (winner == 0){ if (winner == 0){
opponent_choice = opponent_random_action(table, board, state_hash(board)); opponent_choice = opponent_random_action(board);
if (opponent_choice != -1){ if (opponent_choice != -1){
a.loc = opponent_choice; a.loc = opponent_choice;
a.player = OPPONENT_SYMBOL; a.player = OPPONENT_SYMBOL;
act(board, &a, &_state, &opponent_r, &r, &winner); act(board, &a, _state, &opponent_r, &r, &winner);
if (plot) show(board); if (plot) show(board);
} }
} }
@ -208,17 +241,19 @@ void run(float *table, short *board, bool train, int times, bool plot){
} }
real_r = r; real_r = r;
} else { } else {
estimate_r_ = get_estimate_reward(table, board, _state); estimate_r_ = get_estimate_reward(map, board, _state);
real_r = r + LAMBDA * estimate_r_; real_r = r + LAMBDA * estimate_r_;
} }
if (train){ if (train){
// printf("update"); state_weights[choice] += (LR * (real_r - estimate_r));
*(table + state * ACTION_NUM + choice) += ( LR * (real_r - estimate_r) ); // table[state][choice] += LR * (real_r - estimate_r) update(map, state, choice, state_weights[choice]);
} }
state = _state; for (int i=0; i<BIGNUM_LEN; i++){
state[i] = _state[i];
}
if ((winner != 0) || (available_actions_length == 0)){ if ((winner != 0) || (available_actions_length == 0)){
// printf("break\n");
if (winner == 1){ if (winner == 1){
win += 1; win += 1;
} }
@ -228,5 +263,6 @@ void run(float *table, short *board, bool train, int times, bool plot){
} }
if (!train) if (!train)
printf("%d/%d, %f\%\n", win, 10000, (float)win/10000); // printf("%d/%d, %f\%\n", win, 10000, (float)win/10000);
printf("%f\n", (float)win/times);
} }

View File

@ -1,6 +1,7 @@
#include "hash-table.h"
short float_argmax(float *arr, short length); short float_argmax(float *arr, short length);
short bot_choose_action(float *table, short *board, int state); short bot_choose_action(struct Node **map, short *board, char *state);
short opponent_random_action(float *table, short *board, int state); short opponent_random_action(short *board);
void init_table(float *table); float get_estimate_reward(struct Node **map, short *board, char *state);
float get_estimate_reward(float *table, short *board, int state); void run(struct Node **map, short *board, bool train, int times, bool plot);
void run(float *table, short *board, bool train, int times, bool plot);