PokeFusion/main.py
2023-12-24 20:02:07 +08:00

370 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import open3d as o3d
import cv2
from pokemon import Pokemon
import random
# World coordinates of the four paper corners (presumably centimetres — TODO
# confirm). The order must match the corner order produced by
# initial_approx / get_new_approx, since both are passed pairwise to solvePnP.
PAPER_WCS_POINT = np.array([[0, 0, 0], [18.5, 0, 0], [0, 26, 0], [18.5, 26, 0]], dtype=np.float32)

# Intrinsic calibration, loaded (and unpickled) once. The file stores a
# pickled dict with key 'K' (camera matrix) and 'dist' (distortion coefficients).
# NOTE: allow_pickle=True can execute arbitrary code embedded in the file;
# only load a calibration file you produced yourself.
_CAMERA_PARAMETERS = np.load('camera_parameters.npy', allow_pickle=True).item()
CAMERA_MATRIX = _CAMERA_PARAMETERS['K']
DISTORTION_MATRIX = _CAMERA_PARAMETERS['dist']

# Pools new_pokemon() draws from. Full lists, currently narrowed to one entry each:
# types = ['fire', 'water', 'electric', 'rock', 'grass', 'ice', 'steel']
# animals = ['cat', 'dog', 'rat', 'rabbit', 'dragon', 'duck', 'turtle', 'butterfly', 'monkey', 'bee', 'fox', 'flower', 'horse', 'jellyfish', 'snake', 'Tyrannosaurus', 'dinosaur', 'fish', 'whale', 'bat', 'bear', 'deer', 'pig', 'eagle', 'chicken']
types = ['fire']
animals = ['cat']
def get_new_approx(approx, previous_approx, max_distance=500):
    """Match this frame's detected corners to the corners tracked last frame.

    Each previous corner, in order, is greedily paired with the nearest
    not-yet-claimed corner of ``approx`` whose Euclidean distance is at most
    ``max_distance`` (same units as the point coordinates, i.e. pixels for
    contour points). Keeping the previous order lets the caller pair the
    corners with ``PAPER_WCS_POINT`` consistently from frame to frame.

    Args:
        approx: Corners detected this frame; each entry is shaped ``[[a, b]]``
            (the per-point layout of ``cv2.approxPolyDP`` output).
        previous_approx: Corners from the previous frame, same per-point layout.
        max_distance: Largest distance accepted between a previous corner and
            its match. Defaults to 500, the previously hard-coded value.

    Returns:
        A list of entries of ``approx`` in the order of ``previous_approx``
        when exactly four corners were matched; otherwise ``previous_approx``
        itself, unchanged.
    """
    # np.delete returns a new array and leaves ``approx`` untouched, so claimed
    # corners are tracked by index instead; this keeps several previous corners
    # from collapsing onto the same detected corner.
    used = set()
    new_approx = []
    for pre_point in previous_approx:
        px, py = pre_point[0]
        best_index, best_d = None, float('inf')
        for index, point in enumerate(approx):
            if index in used:
                continue
            x, y = point[0]
            d = np.sqrt((x - px) ** 2 + (y - py) ** 2)
            # Strict '<' keeps the earliest index on exact ties.
            if d <= max_distance and d < best_d:
                best_index, best_d = index, d
        if best_index is not None:
            used.add(best_index)
            new_approx.append(approx[best_index])
    if len(new_approx) == 4:
        return new_approx
    return previous_approx
def initial_approx(approx):
    """Put four freshly detected corners into a fixed, repeatable order.

    Ordering rule, applied to each point's raw coordinate pair ``point[0]``:
      1. the corner with the smallest ``a**2 + b**2`` (closest to the origin),
      2. of the two remaining corners, the one with the larger first coordinate,
      3. the other remaining corner,
      4. the corner with the largest ``a**2 + b**2``.
    Ties keep the earliest index for the nearest/farthest choice. If the two
    remaining corners share a first coordinate, the later index goes second.

    The __main__ loop reverses every contour point's two coordinates before
    calling this, so there the "first coordinate" is the image row.

    Args:
        approx: Exactly four points, each shaped like ``[[a, b]]``.

    Returns:
        A list of the four point entries of ``approx`` in the order above.

    Raises:
        ValueError: when the nearest and farthest corner resolve to the same
            index (e.g. all four distances are equal).
    """
    squared = [a ** 2 + b ** 2 for a, b in (pt[0] for pt in approx)]

    nearest = farthest = -100
    nearest_d, farthest_d = 1e9, -1
    for idx, d in enumerate(squared):
        if d < nearest_d:
            nearest, nearest_d = idx, d
        if d > farthest_d:
            farthest, farthest_d = idx, d

    leftover = list(range(4))
    leftover.remove(farthest)
    leftover.remove(nearest)
    p, q = leftover
    # The leftover corner with the larger first coordinate goes second.
    if approx[p][0][0] > approx[q][0][0]:
        second, third = p, q
    else:
        second, third = q, p

    return [approx[i] for i in (nearest, second, third, farthest)]
def new_pokemon():
    """Create a Pokemon with a randomly chosen type and animal.

    The type is drawn from the module-level ``types`` list, then the animal
    from ``animals``. The pair is printed before the Pokemon is constructed.
    """
    def pick(pool):
        # Uniformly choose one entry of ``pool`` (index drawn with randint).
        return pool[random.randint(0, len(pool) - 1)]

    chosen_type = pick(types)
    chosen_animal = pick(animals)
    print(chosen_type, chosen_animal)
    return Pokemon(chosen_type, chosen_animal)
def plot_corner_points(frame, approx):
    """Draw the first four corners of ``approx`` onto ``frame`` in place.

    Each corner is stored as ``[[a, b]]``. The pair is unpacked as (y, x) and
    drawn at (x, y). Corners 0 and 1 are red, corners 2 and 3 are blue (BGR).
    Even indices get radius 15 and odd indices radius 30.
    """
    markers = (
        (15, (0, 0, 255)),
        (30, (0, 0, 255)),
        (15, (255, 0, 0)),
        (30, (255, 0, 0)),
    )
    for index, (radius, bgr) in enumerate(markers):
        row, col = approx[index][0]
        cv2.circle(frame, (col, row), radius, bgr, -1)  # filled circle
def _camera_frustum_arrays(inverse_matrix_M, box_size):
    """Build vertex, line, line-color and triangle arrays for camera frustums.

    Emits one pyramid per matrix in ``inverse_matrix_M``: the apex is the
    transformed origin, and the base is a square of half-width ``box_size`` at
    unit depth. The z coordinate is negated, matching the z flip applied to
    the Pokemon points. Consecutive apexes are joined by a red trajectory line.

    Args:
        inverse_matrix_M: Sequence of 4x4 matrices applied to homogeneous
            points (presumably camera-to-world, i.e. inverted extrinsics).
        box_size: Half-width of each frustum's base square.

    Returns:
        ``(points, lines, line_colors, triangles)`` as numpy arrays, or
        ``None`` when ``inverse_matrix_M`` is empty.
    """
    frustum_corners = np.array([
        [0, 0, 0, 1],
        [-box_size, -box_size, 1, 1],
        [-box_size, box_size, 1, 1],
        [box_size, box_size, 1, 1],
        [box_size, -box_size, 1, 1],
    ], dtype=np.float64)

    points, lines, line_colors, triangles = [], [], [], []
    for m in inverse_matrix_M:
        apex = len(points)
        for corner in frustum_corners:
            location = m @ corner
            location[2] = -location[2]
            points.append(location)
        base = [apex + 1, apex + 2, apex + 3, apex + 4]
        # Base square outline, then apex-to-corner edges; all eight in black.
        for i in range(4):
            lines.append([base[i], base[(i + 1) % 4]])
        for corner_index in base:
            lines.append([apex, corner_index])
        line_colors.extend([[0.0, 0.0, 0.0]] * 8)
        # Red trajectory segment back to the previous camera's apex.
        if apex != 0:
            lines.append([apex, apex - 5])
            line_colors.append([1.0, 0.0, 0.0])
        # Base square as two triangles, each also listed with reversed winding
        # (presumably so it renders from both sides).
        triangles += [
            [apex + 1, apex + 2, apex + 3],
            [apex + 1, apex + 3, apex + 4],
            [apex + 3, apex + 2, apex + 1],
            [apex + 4, apex + 3, apex + 1],
        ]

    if not points:
        return None
    return (
        np.array(points, dtype=np.float64)[:, :-1],  # drop homogeneous coordinate
        np.array(lines, dtype=np.int32),
        np.array(line_colors, dtype=np.float64),
        np.array(triangles, dtype=np.int32),
    )


def visualization(inverse_matrix_M, pokemon, box_size=0.3):
    """Show the paper, world axes, Pokemon model and camera poses in Open3D.

    Opens a blocking Open3D window and returns after the user closes it.

    Args:
        inverse_matrix_M: Sequence of 4x4 matrices, one per recorded frame
            (presumably camera-to-world; see the pose code in __main__). May
            be empty, in which case no cameras are drawn.
        pokemon: Object exposing ``get_position()`` (rows of at least three
            coordinates) and ``get_colors()``.
        box_size: Half-width of each camera frustum's base square.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window()

    paper_points = np.asarray(PAPER_WCS_POINT, dtype=np.float64)

    # Paper outline in red. Open3D colors are floats in [0, 1], not 0-255.
    paper_line_set = o3d.geometry.LineSet()
    paper_line_set.points = o3d.utility.Vector3dVector(paper_points)
    paper_line_set.lines = o3d.utility.Vector2iVector(
        np.array([[0, 1], [0, 2], [1, 3], [2, 3]], dtype=np.int32))
    paper_line_set.colors = o3d.utility.Vector3dVector(np.tile([1.0, 0.0, 0.0], (4, 1)))
    vis.add_geometry(paper_line_set)

    # World axes of length 10: x red, y green, z blue.
    axes = o3d.geometry.LineSet()
    axes.points = o3d.utility.Vector3dVector(
        np.array([[0, 0, 0], [10, 0, 0], [0, 10, 0], [0, 0, 10]], dtype=np.float64))
    axes.lines = o3d.utility.Vector2iVector(np.array([[0, 1], [0, 2], [0, 3]], dtype=np.int32))
    axes.colors = o3d.utility.Vector3dVector(np.eye(3))
    vis.add_geometry(axes)

    # Paper surface: two triangles, each also listed with reversed winding.
    paper_mesh = o3d.geometry.TriangleMesh()
    paper_mesh.vertices = o3d.utility.Vector3dVector(paper_points)
    paper_mesh.triangles = o3d.utility.Vector3iVector(
        np.array([[0, 1, 2], [1, 2, 3], [2, 1, 0], [3, 2, 1]], dtype=np.int32))
    vis.add_geometry(paper_mesh)

    # Pokemon model. np.array copies the data. A plain slice would be a view,
    # so negating z would then modify the array get_position() returned.
    pokemon_point = np.array(pokemon.get_position()[:, :3], dtype=np.float64)
    pokemon_point[:, 2] = -pokemon_point[:, 2]
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pokemon_point)
    # Channel order reversed (presumably BGR -> RGB); assumes exactly three
    # color columns — TODO confirm against Pokemon.get_colors().
    pcd.colors = o3d.utility.Vector3dVector(pokemon.get_colors()[:, ::-1])
    vis.add_geometry(pcd)

    # Camera frustums and trajectory.
    frustums = _camera_frustum_arrays(inverse_matrix_M, box_size)
    if frustums is not None:
        points, lines, line_colors, triangles = frustums
        line_set = o3d.geometry.LineSet()
        line_set.points = o3d.utility.Vector3dVector(points)
        line_set.lines = o3d.utility.Vector2iVector(lines)
        line_set.colors = o3d.utility.Vector3dVector(line_colors)
        vis.add_geometry(line_set)

        frustum_mesh = o3d.geometry.TriangleMesh()
        frustum_mesh.vertices = o3d.utility.Vector3dVector(points)
        frustum_mesh.triangles = o3d.utility.Vector3iVector(triangles)
        vis.add_geometry(frustum_mesh)

    vis.run()
    # Release the window once the user closes the view.
    vis.destroy_window()
def render_pokemon(frame, M, camera_position, pokemon, radius=10):
    """Project the Pokemon's points into ``frame`` and draw them as filled circles.

    Points are drawn farthest-first from ``camera_position``, so nearer points
    overwrite farther ones (painter's algorithm). A caption naming the
    Pokemon's type and animal is drawn last.

    Args:
        frame: Image to draw on. Modified in place and also returned.
        M: 3x4 extrinsic matrix ``[R | t]``. It is pre-multiplied by
            ``CAMERA_MATRIX`` to form the projection.
        camera_position: Camera centre in homogeneous world coordinates,
            broadcast against the ``(N, 4)`` points from ``get_position()``.
        pokemon: Object exposing ``get_position()``, ``get_colors()``,
            ``type`` and ``animal``. Colors are scaled by 255, so they are
            assumed to be floats in [0, 1]; only the first three channels are used.
        radius: Circle radius in pixels for each projected point.

    Returns:
        The annotated frame.
    """
    points = pokemon.get_position()
    colors = pokemon.get_colors()

    # Painter's algorithm: farthest points first.
    distance = np.linalg.norm(points - camera_position, axis=1)
    order = np.argsort(-distance)
    points = points[order]
    colors = colors[order]

    projected = (CAMERA_MATRIX.dot(M) @ points.T).T
    # Points on or behind the camera plane have no valid projection. Dividing
    # by zero or negative depth yields inf/NaN or a mirrored position.
    in_front = projected[:, 2] > 0
    projected = projected[in_front]
    colors = colors[in_front]

    # The projected pair is reversed into OpenCV's (x, y) order. In __main__
    # the image points given to solvePnP are stored as (row, col).
    pixels = np.int_(projected[:, :2] / projected[:, 2:3])[:, ::-1]
    bgr = (colors[:, :3] * 255).astype(int)

    height, width = frame.shape[:2]
    # tolist() hands cv2 plain Python ints.
    for (x, y), color in zip(pixels.tolist(), bgr.tolist()):
        if 0 <= x < width and 0 <= y < height:
            frame = cv2.circle(frame, (x, y), radius, color, thickness=-1)

    frame = cv2.putText(
        frame,
        "{} type, {} pokemon".format(pokemon.type, pokemon.animal),
        (350, 50),
        cv2.FONT_HERSHEY_SIMPLEX,
        2,
        (255, 255, 0),
        2
    )
    return frame
if __name__ == '__main__':
    # Camera source: 0 is the first attached camera (e.g. a phone exposed as a
    # webcam; use 1 for another device). A file path such as 'demo1.mp4' also works.
    cap = cv2.VideoCapture(0)
    # Set to True to open the Open3D view of all recorded camera poses on exit.
    show_camera_poses = False

    previous_approx = []
    pokemon = new_pokemon()
    # Per-frame extrinsics. Despite its name, rotation_vectors holds the 3x3
    # rotation matrices produced by cv2.Rodrigues.
    rotation_vectors = []
    translation_vectors = []

    # Movement keys, mapped to the Pokemon method each one calls.
    key_actions = {
        ord('w'): 'walk_backward',
        ord('s'): 'walk_forward',
        ord('a'): 'walk_left',
        ord('d'): 'walk_right',
        ord('o'): 'rotate_right',
        ord('p'): 'rotate_left',
    }
    dilate_kernel = np.ones((5, 5), np.uint8)

    # Create the resizable display windows once, not on every frame.
    for window_name in ('Paper Detection(edge)', 'Paper Detection(dilate edge)', 'Paper Detection'):
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Grayscale + blur, then Canny edges. The edges are dilated,
            # presumably to close small gaps in the paper outline.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = cv2.GaussianBlur(gray, (5, 5), 0)
            edges = cv2.Canny(gray, 30, 90)
            dilate_edges = cv2.dilate(edges, dilate_kernel, iterations=1)

            contours, _ = cv2.findContours(dilate_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Recent OpenCV returns a tuple here, so comparing with [] was
            # always "not equal". An empty result then crashed in max().
            if contours:
                # Approximate the largest contour by a polygon.
                max_contour = max(contours, key=cv2.contourArea)
                epsilon = 0.05 * cv2.arcLength(max_contour, True)
                approx = cv2.approxPolyDP(max_contour, epsilon, True)
                # Store corners as (row, col); render_pokemon swaps back when drawing.
                approx = approx[:, :, ::-1]

                if len(approx) == 4:
                    # Identify which detected corner is which paper corner by
                    # matching against the previous frame's corners.
                    if not previous_approx:
                        print("INITIAL")
                        previous_approx = initial_approx(approx)
                    new_approx = get_new_approx(approx, previous_approx)
                    previous_approx = new_approx
                    paper_ccs_point = np.concatenate(new_approx, axis=0, dtype=np.float32)
                    # Debug drawing:
                    # plot_corner_points(frame, new_approx)
                    # cv2.drawContours(frame, [approx[:, :, ::-1]], -1, (0, 255, 0), 2)

                    # Camera pose from the 2D-3D corner correspondences.
                    success, rotation_vector, translation_vector = cv2.solvePnP(
                        PAPER_WCS_POINT, paper_ccs_point, CAMERA_MATRIX, DISTORTION_MATRIX)
                    if success:
                        rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
                        rotation_vectors.append(rotation_matrix)
                        translation_vectors.append(translation_vector)

                        rt = np.concatenate((rotation_matrix, translation_vector), axis=1)
                        M = np.concatenate((rt, [[0, 0, 0, 1]]), axis=0)
                        # Camera centre in world coordinates (homogeneous).
                        camera_position = np.linalg.inv(M) @ np.array([0, 0, 0, 1])
                        frame = render_pokemon(frame, rt, camera_position, pokemon)

            cv2.imshow('Paper Detection(edge)', edges)
            cv2.imshow('Paper Detection(dilate edge)', dilate_edges)
            cv2.imshow('Paper Detection', frame)

            # Wait up to 17 ms for a key (caps the loop at roughly 60 fps).
            key = cv2.waitKey(17) & 0xFF
            if key == ord('q'):
                break
            if key == ord(' '):
                # New random Pokemon; re-initialise the corner ordering too.
                pokemon = new_pokemon()
                previous_approx = []
            elif key in key_actions:
                getattr(pokemon, key_actions[key])()
    finally:
        cap.release()
        cv2.destroyAllWindows()

    # Optional: plot every recorded camera pose together with the model.
    if show_camera_poses:
        M_inv = []
        for R, t in zip(rotation_vectors, translation_vectors):
            rt = np.concatenate((R, t), axis=1)
            M = np.concatenate((rt, [[0, 0, 0, 1]]), axis=0)
            M_inv.append(np.linalg.inv(M))
        visualization(M_inv, pokemon)