-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Labels
bug: Something isn't working
Description
Running the code below fails with the error "The CPU kernels do not build correctly. Please check the installation of braintaichi." on a Windows 11 system with Python 3.11.
The code
'''Spatially structured E/I spiking network simulation and visualisation.

Builds a periodic 2-D grid of excitatory (every grid site) and inhibitory
(every second grid site) ExpIF neurons with distance-dependent connectivity,
runs it with BrainPy and renders spikes, membrane potentials and currents.
'''
import os
import time
import types
from concurrent.futures import ProcessPoolExecutor
from typing import Union, Sequence, Callable, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

import brainpy as bp
import brainpy.math as bm
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.initialize import parameter
from brainpy._src.dyn import _docs
from brainpy._src.dyn.base import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.mixin import AlignPost, ReturnInfo
from brainpy.types import ArrayType

bm.set_platform('cpu')

# Number of worker processes used by the multi-process helpers below.
# Defined at module level so that the functions reading it also work when this
# file is imported (the ``__main__`` section may override it).
process_num = 1


def mkdir(fn):
    '''Create directory ``fn`` including any missing parent directories.'''
    os.makedirs(fn, exist_ok=True)


def fig_to_video(fig_paths, filename, frame_rate=24, delete_figs=False, formats=None):
    '''Combine image files into an mp4 video and/or an animated gif.

    Parameters:
    - fig_paths: iterable of image file paths, in playback order. ``None``
      entries and paths that do not exist are skipped.
    - filename: output path WITHOUT extension; ``.mp4`` / ``.gif`` is appended.
    - frame_rate: frames per second of the output.
    - delete_figs: if True, remove the source images after all formats are written.
    - formats: list containing ``'mp4'`` and/or ``'gif'``; defaults to both.
    '''
    if formats is None:
        formats = ['mp4', 'gif']

    # os.makedirs('') raises, so only create the directory when there is one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        mkdir(out_dir)

    # A failed worker leaves ``None`` in the path list; drop those and any
    # path that was never written.
    valid_fig_paths = [p for p in fig_paths if p is not None and os.path.isfile(p)]

    # Ensure there are valid figs to process
    if not valid_fig_paths:
        print("No valid figs to process.")
        return

    if 'mp4' in formats:
        # MP4 video output; the first frame fixes the video size.
        video_filename = f"{filename}.mp4"
        frame = cv2.imread(valid_fig_paths[0])
        height, width, layers = frame.shape
        video_size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_filename, fourcc, frame_rate, video_size)
        try:
            for fig_path in valid_fig_paths:
                fig = cv2.imread(fig_path)
                if fig.shape[1] != video_size[0] or fig.shape[0] != video_size[1]:
                    fig = cv2.resize(fig, video_size, interpolation=cv2.INTER_LANCZOS4)
                out.write(fig)
        finally:
            # Always finalise the container, even if a frame fails.
            out.release()

    if 'gif' in formats:
        # GIF output; every frame is resized to the size of the first one.
        gif_filename = f"{filename}.gif"
        with Image.open(valid_fig_paths[0]) as first_fig:
            gif_size = (first_fig.width, first_fig.height)
        resized_figs = []
        for fig_path in valid_fig_paths:
            # resize() returns an independent in-memory image, so the file
            # handle can be closed right away (needed before os.remove on Windows).
            with Image.open(fig_path) as fig:
                resized_figs.append(fig.resize(gif_size, Image.LANCZOS))
        resized_figs[0].save(gif_filename, save_all=True, append_images=resized_figs[1:],
                             duration=1000/frame_rate, loop=0)

    if delete_figs:
        # Delete figs after processing all formats
        for fig_path in valid_fig_paths:
            os.remove(fig_path)
        print("All valid figs deleted.")

    cv2.destroyAllWindows()


def split_list(lst, n):
    '''Split ``lst`` into ``n`` consecutive sub-lists of near-equal length.

    The first ``len(lst) % n`` sub-lists get one extra element.

    Parameters:
    - lst: list to split.
    - n: int, number of sub-lists.

    Returns:
    list containing ``n`` sub-lists.
    '''
    size, remainder = divmod(len(lst), n)
    divided_list = []
    start = 0
    for i in range(n):
        stop = start + size + (1 if i < remainder else 0)
        divided_list.append(lst[start:stop])
        start = stop
    return divided_list


def flatten_list(input_list, level=None):
    '''Flatten a nested list (only ``list`` instances are expanded).

    Parameters:
    - input_list: nested list.
    - level: number of nesting levels to expand, counted from the outside
      (1 expands one level, 2 expands two, ...); ``None`` expands all levels.
    '''
    if level is None:
        level = float('inf')  # expand every level by default

    def flatten_recursive(lst, curr_level):
        flattened = []
        for item in lst:
            if isinstance(item, list) and curr_level < level:
                flattened.extend(flatten_recursive(item, curr_level + 1))
            else:
                flattened.append(item)
        return flattened

    return flatten_recursive(input_list, 0)


def multi_process(process_num, func, args_list=None, kwargs_list=None, func_name=''):
    '''Run ``func`` once per argument set, optionally in parallel processes.

    Parameters:
    - process_num: int, number of worker processes. With ``process_num == 1``
      everything runs in the current process and exceptions propagate normally
      (useful for debugging, since errors raised in workers are hard to locate).
    - func: function to call; must be picklable when ``process_num != 1``.
    - args_list: list of positional-argument tuples, one per call.
    - kwargs_list: list of keyword-argument dicts, one per call.
    - func_name: str, task name for display purposes (currently unused).

    Notes:
    - If ``args_list`` / ``kwargs_list`` has length 1 it is repeated
      ``process_num`` times.
    - Each entry of ``args_list`` must be a tuple: ``[(1,), (2,)]``, not
      ``[(1), (2)]``.
    - If the length matches neither ``process_num`` nor 1 and
      ``process_num == 1``, the list is flattened by one level.

    Returns:
    list of results in submission order. In multi-process mode a call that
    raised contributes ``None`` and its error message is printed.
    '''
    # Normalise into NEW lists so the caller's lists are never mutated.
    if args_list is None:
        args_list = [()]
    else:
        args_list = [() if args is None else args for args in args_list]
    if kwargs_list is None:
        kwargs_list = [{}]
    else:
        kwargs_list = [{} if kwargs is None else kwargs for kwargs in kwargs_list]

    if len(args_list) != process_num:
        if len(args_list) == 1:
            args_list = args_list * process_num
        elif process_num == 1:
            args_list = flatten_list(args_list, level=1)
        else:
            raise ValueError("The length of args_list must be equal to process_num or 1.")

    if len(kwargs_list) != process_num:
        if len(kwargs_list) == 1:
            kwargs_list = kwargs_list * process_num
        elif process_num == 1:
            kwargs_list = flatten_list(kwargs_list, level=1)
        else:
            raise ValueError("The length of kwargs_list must be equal to process_num or 1.")

    if process_num == 1:
        return [func(*args, **kwargs) for args, kwargs in zip(args_list, kwargs_list)]

    results = []
    with ProcessPoolExecutor(max_workers=process_num) as executor:
        futures = [executor.submit(func, *args, **kwargs)
                   for args, kwargs in zip(args_list, kwargs_list)]
        # Collect in submission order so results line up with the inputs.
        for future in futures:
            try:
                results.append(future.result())
            except Exception as e:
                # Best effort: keep the slot, report the failure.
                results.append(None)
                print(f"An error occurred: {e}")
    return results


def part_list_for(func, for_list, for_idx_name, *args, **kwargs):
    '''Call ``func`` for every element of ``for_list``.

    Each element is passed as the keyword argument named ``for_idx_name``;
    ``args`` / ``kwargs`` are forwarded unchanged. Returns the list of results.
    '''
    return [func(*args, **{**kwargs, for_idx_name: i}) for i in for_list]


def multi_process_list_for(process_num, func, args=None, kwargs=None, for_list=None, for_idx_name='i', func_name=''):
    '''Parallelise a loop of the form ``for i in for_list: func(..., i=i)``.

    Parameters:
    - process_num: int, number of worker processes.
    - func: function to call for each element.
    - args: positional arguments for ``func`` (discouraged, because the loop
      value is always passed by keyword).
    - kwargs: keyword arguments for ``func``.
    - for_list: iterable of loop values.
    - for_idx_name: name of the keyword argument that receives the loop value.
    - func_name: str, task name for display purposes (forwarded, currently unused).

    Notes:
    Only valid when the loop iterations are independent of each other.

    Returns:
    flat list with one result per loop value, in loop order.
    '''
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}

    for_list = list(for_list)  # materialise generators / ranges / arrays
    divided_list = split_list(for_list, process_num)
    args_list = [(func, divided, for_idx_name) + args for divided in divided_list]
    kwargs_list = [kwargs] * process_num
    chunked = multi_process(process_num, part_list_for, args_list, kwargs_list, func_name)
    return flatten_list(chunked, level=1)


class Func(SynDyn, AlignPost):
    '''Synapse-like component whose conductance ``g`` is set from ``func(t)``.'''

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
        func: Optional[Callable] = None,
    ):
        super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

        # parameters
        self.func = func  # callable evaluated with the shared time ``share['t']``
        self._current = None

        self.reset_state(self.mode)

    def reset_state(self, batch_or_mode=None, **kwargs):
        '''(Re)create ``g`` as a zero-initialised variable.'''
        self.g = self.init_variable(bm.zeros, batch_or_mode)

    def update(self, x=None):
        '''Overwrite ``g`` with ``func(t)`` broadcast to the shape of ``g``.'''
        self.g.value = bm.ones_like(self.g.value) * self.func(share['t'])
        return self.g.value

    def add_current(self, x):
        '''Add ``x`` to ``g`` in place.'''
        self.g.value += x

    def return_info(self):
        return self.g


class ExponentialCOBA(bp.Projection):
    '''Projection with an exponential synapse and conductance-based output.'''

    def __init__(self, pre, post, delay, tau, E, comm):
        super().__init__()
        self.proj = bp.dyn.FullProjAlignPost(
            pre=pre,
            delay=delay,
            comm=comm,
            syn=bp.dyn.Expon(size=post.num, tau=tau),  # exponential synapse
            out=bp.dyn.COBA(E=E),
            post=post
        )


class NormalizedDualExponV2(bp.dyn.DualExponV2):
    '''``DualExponV2`` whose default ``A`` is ``1 / (tau_decay - tau_rise)``.

    See https://brainpy.readthedocs.io/en/latest/apis/generated/brainpy.dyn.DualExponV2.html
    With this default the dual-exponential kernel integrates to 1.

    To read the conductance of a projection ``syn`` built on this class,
    combine the two state variables and the scale factor:

        (syn.proj.refs['syn'].g_decay - syn.proj.refs['syn'].g_rise) * syn.proj.refs['syn'].a
    '''

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        method: str = 'exp_auto',
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,

        # synapse parameters
        tau_decay: Union[float, ArrayType, Callable] = 10.0,
        tau_rise: Union[float, ArrayType, Callable] = 1.,
        A: Optional[Union[float, ArrayType, Callable]] = None,
    ):
        super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

        def _format_dual_exp_A(self, A):
            A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding)
            if A is None:
                # Normalisation: integral of exp(-t/tau_decay) - exp(-t/tau_rise)
                # over t >= 0 equals tau_decay - tau_rise.
                A = 1 / (self.tau_decay - self.tau_rise)
            return A

        # parameters
        self.tau_rise = self.init_param(tau_rise)
        self.tau_decay = self.init_param(tau_decay)
        self.a = _format_dual_exp_A(self, A)

        # integrator
        self.integral = odeint(lambda g, t, tau: -g / tau, method=method)

        self.reset_state(self.mode)


class NormalizedDualExponCOBA(bp.Projection):
    '''Projection with a normalised dual-exponential synapse and COBA output.'''

    def __init__(self, pre, post, delay, comm, tau_rise, tau_decay, E, out_label=None):
        super().__init__()
        self.proj = bp.dyn.FullProjAlignPostMg(
            pre=pre,
            delay=delay,
            comm=comm,
            syn=NormalizedDualExponV2.desc(post.num, tau_rise=tau_rise, tau_decay=tau_decay),
            out=bp.dyn.COBA.desc(E),
            post=post,
            out_label=out_label
        )


class FuncCUBA(bp.Projection):
    '''Function-driven input delivered as a current (CUBA: current-based).'''

    def __init__(self, pre, post, delay, func, out_label=None):
        super().__init__()
        self.proj = bp.dyn.FullProjAlignPost(
            pre=pre,
            delay=delay,
            comm=bp.dnn.AllToAll(pre.num, post.num, 1.),
            syn=Func(size=post.num, func=func),
            out=bp.dyn.CUBA(),
            post=post,
            out_label=out_label
        )


class FuncCOBA(bp.Projection):
    '''Function-driven input delivered as a conductance (COBA: conductance-based).'''

    def __init__(self, pre, post, delay, func, E, out_label=None):
        super().__init__()
        self.proj = bp.dyn.FullProjAlignPost(
            pre=pre,
            delay=delay,
            comm=bp.dnn.AllToAll(pre.num, post.num, 1.),
            syn=Func(size=post.num, func=func),
            out=bp.dyn.COBA(E=E),
            post=post,
            out_label=out_label
        )


def ij_conn(pre, post, pre_size, post_size):
    '''Build a connector from explicit (pre, post) index arrays via ``bp.conn.IJConn``.'''
    conn = bp.conn.IJConn(i=pre, j=post)
    conn = conn(pre_size=pre_size, post_size=post_size)
    return conn


def ij_comm(pre, post, pre_size, post_size, weight):
    '''Build a ``bp.dnn.EventCSRLinear`` from explicit (pre, post) indices and weights.'''
    conn = ij_conn(pre, post, pre_size, post_size)
    return bp.dnn.EventCSRLinear(conn, weight)


class EINet(bp.DynamicalSystem):
    '''Excitatory/inhibitory ExpIF network on a periodic 2-D grid.'''

    def __init__(self, grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func):
        super().__init__()
        self.location, self.conn_weight = generate_conn_and_weight(
            grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE)
        ne = self.location['E_num']
        ni = self.location['I_num']

        # neurons
        self.E = bp.dyn.ExpIFRef(ne, V_rest=-70., V_reset=-70., V_th=-40., V_T=-60.6250, delta_T=6.5625,
                                 tau=20., R=1., tau_ref=5., V_initializer=bp.init.Normal(-55., 10.))
        self.I = bp.dyn.ExpIFRef(ni, V_rest=-70., V_reset=-70., V_th=-40., V_T=-60.6250, delta_T=6.5625,
                                 tau=20., R=1., tau_ref=5., V_initializer=bp.init.Normal(-55., 10.))
        self.E_inp = bp.dyn.InputGroup(size=1)  # placeholder
        self.I_inp = bp.dyn.InputGroup(size=1)  # placeholder

        # recurrent connections
        neuron_comm = get_comm(self.location, self.conn_weight)
        self.E2E = NormalizedDualExponCOBA(pre=self.E, post=self.E, delay=0., comm=neuron_comm['E2E_comm'],
                                           tau_rise=0.3, tau_decay=2., E=0., out_label='E')
        self.E2I = NormalizedDualExponCOBA(pre=self.E, post=self.I, delay=0., comm=neuron_comm['E2I_comm'],
                                           tau_rise=0.3, tau_decay=2., E=0., out_label='E')
        self.I2E = NormalizedDualExponCOBA(pre=self.I, post=self.E, delay=0., comm=neuron_comm['I2E_comm'],
                                           tau_rise=0.3, tau_decay=3., E=-80., out_label='I')
        self.I2I = NormalizedDualExponCOBA(pre=self.I, post=self.I, delay=0., comm=neuron_comm['I2I_comm'],
                                           tau_rise=0.3, tau_decay=3., E=-80., out_label='I')

        # additional constant conductances
        addition_func = lambda t: 0.001 * 20
        self.additionE2E = FuncCOBA(pre=self.E_inp, post=self.E, delay=0., func=addition_func, E=0., out_label='E')
        self.additionE2I = FuncCOBA(pre=self.I_inp, post=self.I, delay=0., func=addition_func, E=0., out_label='E')
        self.additionI2E = FuncCOBA(pre=self.E_inp, post=self.E, delay=0., func=addition_func, E=-80., out_label='I')
        self.additionI2I = FuncCOBA(pre=self.I_inp, post=self.I, delay=0., func=addition_func, E=-80., out_label='I')

        # external inputs
        self.E_inp2E = FuncCUBA(pre=self.E_inp, post=self.E, delay=0., func=E_inp_func, out_label='input')
        self.I_inp2I = FuncCUBA(pre=self.I_inp, post=self.I, delay=0., func=I_inp_func, out_label='input')

        # Poisson inputs
        self.Poisson_inp_for_E = bp.dyn.PoissonGroup(size=ne, freqs=10.)
        self.Poisson_inp_for_E2E = NormalizedDualExponCOBA(pre=self.Poisson_inp_for_E, post=self.E, delay=0.,
                                                           comm=bp.dnn.OneToOne(ne, ne, 0.2),
                                                           tau_rise=0.3, tau_decay=2., E=0., out_label='input')
        self.Poisson_inp_for_I = bp.dyn.PoissonGroup(size=ni, freqs=10.)
        self.Poisson_inp_for_I2I = NormalizedDualExponCOBA(pre=self.Poisson_inp_for_I, post=self.I, delay=0.,
                                                           comm=bp.dnn.OneToOne(ni, ni, 0.2),
                                                           tau_rise=0.3, tau_decay=2., E=0., out_label='input')

    def update(self):
        '''Advance one step: all projections first, then both neuron groups.'''
        self.E2E()
        self.E2I()
        self.I2E()
        self.I2I()
        self.additionE2E()
        self.additionE2I()
        self.additionI2E()
        self.additionI2I()
        self.E_inp2E()
        self.I_inp2I()
        self.Poisson_inp_for_E2E()
        self.Poisson_inp_for_I2I()
        self.E()
        self.I()


def generate_location(grid_num, grid_distance):
    '''Generate neuron positions on a ``grid_num`` x ``grid_num`` grid.

    ``grid_distance`` is the spacing between adjacent grid sites. E neurons
    occupy every site (index step 1), I neurons every second site (index step 2).

    Returns the position meshes and grid-index meshes for E, then for I:
    ``E_x_mesh, E_y_mesh, E_i_mesh, E_j_mesh, I_x_mesh, I_y_mesh, I_i_mesh, I_j_mesh``.
    '''
    grid_loc_x = np.arange(0, grid_num) * grid_distance
    grid_loc_y = np.arange(0, grid_num) * grid_distance
    grid_idx_E = np.arange(0, grid_num)
    grid_idx_I = grid_idx_E[::2]
    E_x_mesh, E_y_mesh = np.meshgrid(grid_loc_x[grid_idx_E], grid_loc_y[grid_idx_E])
    E_i_mesh, E_j_mesh = np.meshgrid(grid_idx_E, grid_idx_E)
    I_x_mesh, I_y_mesh = np.meshgrid(grid_loc_x[grid_idx_I], grid_loc_y[grid_idx_I])
    I_i_mesh, I_j_mesh = np.meshgrid(grid_idx_I, grid_idx_I)
    return E_x_mesh, E_y_mesh, E_i_mesh, E_j_mesh, I_x_mesh, I_y_mesh, I_i_mesh, I_j_mesh


def part_generate_conn_and_weight(idx, location, pre_group, pre_step, post_group, post_step, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE):
    '''Generate the outgoing connections of one presynaptic neuron.

    Parameters:
    - idx: flat index into the presynaptic group's grid-index meshes.
    - location: dict produced by ``generate_conn_and_weight``.
    - pre_group / post_group: ``'E'`` or ``'I'``.
    - pre_step / post_step: grid-index spacing of the pre / post group.
    - E_conn_grid_num / I_conn_grid_num: connection radius (in grid units)
      used when the presynaptic group is E / I.
    - wE, wI, sigmaE: weight scales and the E Gaussian-profile parameter.

    Returns:
    ``(pre_idx_list, post_idx_list, weight_list)`` with one entry per connection.
    '''
    # The E meshes are grid_num x grid_num. Reading the size from ``location``
    # (instead of a module global) keeps this function usable inside spawned
    # worker processes, where ``__main__``-only globals do not exist.
    grid_num = location['E_i_mesh'].shape[0]

    # connection radius is chosen by the presynaptic group
    if pre_group == 'E':
        conn_grid_num = E_conn_grid_num
    elif pre_group == 'I':
        conn_grid_num = I_conn_grid_num

    conn_weight_pre_idx = []
    conn_weight_post_idx = []
    conn_weight_weight = []

    # pre_i/pre_j and post_i/post_j are indices on the full grid, so E and I
    # have different spacings; they are converted to contiguous per-group
    # indices further below.
    pre_i = location[f'{pre_group}_i_mesh'].flatten()[idx]
    pre_j = location[f'{pre_group}_j_mesh'].flatten()[idx]

    # Only scan a square around the presynaptic neuron to limit the work;
    # bounds are snapped to multiples of post_step.
    square_range = conn_grid_num + post_step  # slightly enlarged range
    post_i_start = int((pre_i - square_range) // post_step * post_step)
    post_i_end = int((pre_i + square_range) // post_step * post_step)
    post_j_start = int((pre_j - square_range) // post_step * post_step)
    post_j_end = int((pre_j + square_range) // post_step * post_step)
    post_i_range = np.arange(post_i_start, post_i_end, step=post_step)
    post_j_range = np.arange(post_j_start, post_j_end, step=post_step)

    # periodic boundary conditions
    post_i_range = post_i_range % grid_num
    post_j_range = post_j_range % grid_num

    # wrapping can create duplicates; remove them
    post_i_range = np.unique(post_i_range)
    post_j_range = np.unique(post_j_range)

    mesh_post_i, mesh_post_j = np.meshgrid(post_i_range, post_j_range)
    for post_i, post_j in zip(mesh_post_i.flatten(), mesh_post_j.flatten()):
        # shortest distance along each axis under periodic boundaries
        i_distance = np.min([np.abs((pre_i - post_i)), np.abs((pre_i - post_i + grid_num)), np.abs((pre_i - post_i - grid_num))])
        j_distance = np.min([np.abs((pre_j - post_j)), np.abs((pre_j - post_j + grid_num)), np.abs((pre_j - post_j - grid_num))])
        # Euclidean (L2) distance
        l2_distance = np.sqrt(i_distance ** 2 + j_distance ** 2)
        if l2_distance <= conn_grid_num:
            # dividing by the spacing yields the contiguous per-group index
            conn_weight_pre_idx.append(int(round((pre_i * location[f'{pre_group}_grid_num'] + pre_j)/pre_step)))
            conn_weight_post_idx.append(int(round((post_i * location[f'{post_group}_grid_num'] + post_j)/post_step)))
            if pre_group == 'E':
                conn_weight_weight.append(wE * np.exp(- l2_distance**2 / sigmaE) * np.abs(np.random.normal(1, 0.4)))
            elif pre_group == 'I':
                conn_weight_weight.append(wI * np.abs(np.random.normal(1, 0.4)))
    return conn_weight_pre_idx, conn_weight_post_idx, conn_weight_weight


def generate_conn_and_weight(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE):
    '''Generate connection indices and weights for E2E, E2I, I2E and I2I.

    Returns:
    ``(location, conn_weight)``: ``location`` holds the meshes, neuron counts,
    index steps and per-group grid sizes; ``conn_weight`` holds, for each
    ``'{pre}2{post}'`` pair, the ``_pre_idx``, ``_post_idx`` and ``_weight`` arrays.

    Raises:
    ValueError: if ``grid_num`` is not a multiple of the I index step (2).
    RuntimeError: if a worker process failed to produce its connections.
    '''
    E_step, I_step = 1, 2
    if grid_num % I_step != 0:
        # With an odd grid the periodic wrap lands on sites that hold no I
        # neuron and the contiguous I indices collide.
        raise ValueError(f"grid_num must be a multiple of {I_step}, got {grid_num}.")

    location = {}

    # positions
    (location['E_x_mesh'], location['E_y_mesh'], location['E_i_mesh'], location['E_j_mesh'],
     location['I_x_mesh'], location['I_y_mesh'], location['I_i_mesh'], location['I_j_mesh']) = \
        generate_location(grid_num, grid_distance)
    location['E_num'] = location['E_x_mesh'].size
    location['I_num'] = location['I_x_mesh'].size
    location['E_step'] = E_step
    location['I_step'] = I_step
    location['E_grid_num'] = grid_num // location['E_step']
    location['I_grid_num'] = grid_num // location['I_step']

    # connection indices and weights
    conn_weight = {}
    for pre_group in ['E', 'I']:
        for post_group in ['E', 'I']:
            print(f'{pre_group}2{post_group}')
            pre_key = f'{pre_group}2{post_group}_pre_idx'
            post_key = f'{pre_group}2{post_group}_post_idx'
            weight_key = f'{pre_group}2{post_group}_weight'
            conn_weight[pre_key] = []
            conn_weight[post_key] = []
            conn_weight[weight_key] = []
            pre_step = location[f'{pre_group}_step']
            post_step = location[f'{post_group}_step']

            # one task per presynaptic neuron, spread over worker processes
            r = multi_process_list_for(
                process_num=process_num,
                func=part_generate_conn_and_weight,
                kwargs={'location': location, 'pre_group': pre_group, 'pre_step': pre_step,
                        'post_group': post_group, 'post_step': post_step,
                        'E_conn_grid_num': E_conn_grid_num, 'I_conn_grid_num': I_conn_grid_num,
                        'wE': wE, 'wI': wI, 'sigmaE': sigmaE},
                for_list=np.arange(location[f'{pre_group}_i_mesh'].size),
                for_idx_name='idx')

            # gather results
            for sub_r in r:
                if sub_r is None:
                    # multi_process() stores None for a failed worker chunk;
                    # a partial connectivity matrix must not be used silently.
                    raise RuntimeError(
                        f"A worker failed while generating {pre_group}2{post_group} connections; "
                        "rerun with process_num = 1 to see the original error.")
                conn_weight[pre_key].extend(sub_r[0])
                conn_weight[post_key].extend(sub_r[1])
                conn_weight[weight_key].extend(sub_r[2])

            # convert to np.array
            conn_weight[pre_key] = np.array(conn_weight[pre_key])
            conn_weight[post_key] = np.array(conn_weight[post_key])
            conn_weight[weight_key] = np.array(conn_weight[weight_key])

            # report duplicated (pre, post) pairs
            pre_post_idx = np.stack([conn_weight[pre_key], conn_weight[post_key]], axis=1)
            unique_pre_post_idx, unique_idx = np.unique(pre_post_idx, axis=0, return_index=True)
            if len(unique_idx) != len(pre_post_idx):
                print('有重复的连接')
    return location, conn_weight


def get_comm(location, conn_weight):
    '''Build one ``EventCSRLinear`` comm per group pair, keyed ``'{pre}2{post}_comm'``.'''
    neuron_comm = {}
    for pre_group in ['E', 'I']:
        for post_group in ['E', 'I']:
            print(f'{pre_group}2{post_group}_comm')
            neuron_comm[f'{pre_group}2{post_group}_comm'] = ij_comm(
                pre=conn_weight[f'{pre_group}2{post_group}_pre_idx'],
                post=conn_weight[f'{pre_group}2{post_group}_post_idx'],
                pre_size=location[f'{pre_group}_num'],
                post_size=location[f'{post_group}_num'],
                weight=conn_weight[f'{pre_group}2{post_group}_weight'])
    return neuron_comm


def set_xylim(ax, grid_num, grid_distance):
    '''Set both axis limits to the grid extent and use an equal aspect ratio.'''
    xlim = [0, (grid_num-1) * grid_distance]
    ylim = [0, (grid_num-1) * grid_distance]
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_aspect('equal')


def visualize_V_one_step(i, basedir, vmin, vmax, E_V, I_V, ts, location, grid_num, grid_distance, s):
    '''Save a scatter plot of E and I membrane potentials at time index ``i``.

    The figure is written to ``<basedir>/V/<i>.png``; the path is returned.
    '''
    fig_path = os.path.join(basedir, 'V', f'{i}.png')
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

    ax = axes[0]
    sc = ax.scatter(location['E_x_mesh'].flatten(), location['E_y_mesh'].flatten(), c=E_V[i],
                    cmap=plt.cm.jet, vmin=vmin, vmax=vmax, s=s, clip_on=False)
    cbar = plt.colorbar(sc, ax=ax)
    ax.set_title('E')
    set_xylim(ax, grid_num, grid_distance)

    ax = axes[1]
    sc = ax.scatter(location['I_x_mesh'].flatten(), location['I_y_mesh'].flatten(), c=I_V[i],
                    cmap=plt.cm.jet, vmin=vmin, vmax=vmax, s=s, clip_on=False)
    cbar = plt.colorbar(sc, ax=ax)
    ax.set_title('I')
    set_xylim(ax, grid_num, grid_distance)

    fig.suptitle(f't={ts[i]:.3f}')
    fig.savefig(fig_path)
    plt.close(fig)
    return fig_path


class SNN_analyzer:
    '''Build and run an ``EINet``, then provide plotting helpers for the results.

    Construction runs the full simulation for ``time_period`` and creates a
    time-stamped output directory below ``../../results``.
    '''

    def __init__(self, grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period):
        self.grid_num = grid_num
        self.grid_distance = grid_distance
        self.E_conn_grid_num = E_conn_grid_num
        self.I_conn_grid_num = I_conn_grid_num
        self.wE = wE
        self.wI = wI
        self.sigmaE = sigmaE
        self.E_inp_func = E_inp_func
        self.I_inp_func = I_inp_func

        self.net = EINet(grid_num=grid_num, grid_distance=grid_distance, E_conn_grid_num=E_conn_grid_num,
                         I_conn_grid_num=I_conn_grid_num, wE=wE, wI=wI, sigmaE=sigmaE,
                         E_inp_func=E_inp_func, I_inp_func=I_inp_func)
        self.location = self.net.location
        self.conn_weight = self.net.conn_weight

        monitors = {'E.spike': self.net.E.spike, 'I.spike': self.net.I.spike, 'E.V': self.net.E.V, 'I.V': self.net.I.V}
        # summed input currents, selected by projection out_label
        monitors['E.E_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='E')
        monitors['E.I_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='I')
        monitors['I.E_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='E')
        monitors['I.I_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='I')
        monitors['E.input_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='input')
        monitors['I.input_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='input')
        # dual-exponential conductances: (g_decay - g_rise) * a
        monitors['E.E2E_g'] = lambda: (self.net.E2E.proj.refs['syn'].g_decay - self.net.E2E.proj.refs['syn'].g_rise) * self.net.E2E.proj.refs['syn'].a
        monitors['E.E2I_g'] = lambda: (self.net.E2I.proj.refs['syn'].g_decay - self.net.E2I.proj.refs['syn'].g_rise) * self.net.E2I.proj.refs['syn'].a
        monitors['I.I2E_g'] = lambda: (self.net.I2E.proj.refs['syn'].g_decay - self.net.I2E.proj.refs['syn'].g_rise) * self.net.I2E.proj.refs['syn'].a
        monitors['I.I2I_g'] = lambda: (self.net.I2I.proj.refs['syn'].g_decay - self.net.I2I.proj.refs['syn'].g_rise) * self.net.I2I.proj.refs['syn'].a

        self.runner = bp.DSRunner(self.net, monitors=monitors)
        self.runner.run(duration=time_period)

        self.indices = np.arange(int(time_period / bm.get_dt()))
        self.ts = self.indices * bm.get_dt()
        self.E_spike = self.runner.mon['E.spike']
        self.I_spike = self.runner.mon['I.spike']
        self.E_V = self.runner.mon['E.V']
        self.I_V = self.runner.mon['I.V']

        basedir = '../../results'
        current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        self.basedir = os.path.join(basedir, current_time)
        mkdir(self.basedir)

        # scatter marker area, scaled with the grid resolution
        self.s = np.pi * (100 / self.grid_num)**2

    def visualize_V(self, start=0, stop=None, step=10, frame_rate=5, delete_figs=True):
        '''Render membrane-potential frames for ``range(start, stop, step)`` and join them into a video/gif.

        ``stop`` defaults to ``min(len(self.ts), 1000)``.
        '''
        if stop is None:
            stop = np.min([len(self.ts), 1000])
        s = self.s
        mkdir(os.path.join(self.basedir, 'V'))
        fig_paths = multi_process_list_for(
            process_num=process_num,
            func=visualize_V_one_step,
            kwargs={'basedir': self.basedir, 'vmin': self.net.E.V_rest, 'vmax': self.net.E.V_th,
                    'E_V': self.E_V, 'I_V': self.I_V, 'ts': self.ts, 'location': self.location,
                    'grid_num': self.grid_num, 'grid_distance': self.grid_distance, 's': s},
            for_list=np.arange(start, stop, step),
            for_idx_name='i')
        fig_to_video(fig_paths, os.path.join(self.basedir, 'V', 'V_video'), frame_rate=frame_rate, delete_figs=delete_figs)

    def visualize_V_one_neuron(self, neuron_group, row_idx, col_idx):
        '''Plot the membrane potential of one neuron against the time-step index.

        ``row_idx`` / ``col_idx`` address the neuron within its own group's
        grid (``grid_num`` columns for E, ``location['I_grid_num']`` for I).
        '''
        if neuron_group == 'E':
            neuron_idx = row_idx * self.grid_num + col_idx
            V = self.E_V
        elif neuron_group == 'I':
            # explicit row length of the I grid (no reliance on operator precedence)
            neuron_idx = row_idx * self.location['I_grid_num'] + col_idx
            V = self.I_V
        mkdir(os.path.join(self.basedir, 'V_one_neuron'))
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        s = self.s
        ax.scatter(self.indices, V[:, neuron_idx], s=s, clip_on=False)
        ax.set_title(f'{neuron_group} row {row_idx} col {col_idx}')
        ax.set_xlabel('t')
        ax.set_ylabel('V')
        fig.savefig(os.path.join(self.basedir, 'V_one_neuron', f'{neuron_group}_row_{row_idx}_col_{col_idx}.png'))
        plt.close(fig)

    def visualize_spike(self):
        '''Save raster plots of E and I spikes to ``<basedir>/spike/spike.png``.'''
        mkdir(os.path.join(self.basedir, 'spike'))
        fig, ax = plt.subplots(2, 1, figsize=(6, 6))
        bp.visualize.raster_plot(self.ts, self.E_spike, ax=ax[0])
        bp.visualize.raster_plot(self.ts, self.I_spike, ax=ax[1])
        ax[0].set_title('E')
        ax[1].set_title('I')
        fig.suptitle('Spike')
        fig.savefig(os.path.join(self.basedir, 'spike', 'spike.png'))
        plt.close(fig)

    def visualize_current(self, E_neuron_idx=None, I_neuron_idx=None):
        '''Plot the E, I and input currents, averaged over the selected neurons.

        ``E_neuron_idx`` / ``I_neuron_idx`` may be an int, a sequence of ints,
        or ``None`` for all neurons of the group.
        '''
        if isinstance(E_neuron_idx, int):
            E_neuron_idx = (E_neuron_idx, )
        if isinstance(I_neuron_idx, int):
            I_neuron_idx = (I_neuron_idx, )
        if E_neuron_idx is None:
            E_neuron_idx = slice(None)
        if I_neuron_idx is None:
            I_neuron_idx = slice(None)

        mkdir(os.path.join(self.basedir, 'current'))
        fig, axes = plt.subplots(3, 2, figsize=(12, 6))

        # left column: currents into E neurons
        ax = axes[0, 0]
        ax.plot(self.ts, np.mean(self.runner.mon['E.E_current'][:, E_neuron_idx], axis=1), label='E')
        ax = axes[1, 0]
        ax.plot(self.ts, np.mean(self.runner.mon['E.I_current'][:, E_neuron_idx], axis=1), label='I')
        ax = axes[2, 0]
        ax.plot(self.ts, np.mean(self.runner.mon['E.input_current'][:, E_neuron_idx], axis=1), label='input')
        for ax in axes[:, 0]:
            ax.set_title('E')
            ax.legend()

        # right column: currents into I neurons
        ax = axes[0, 1]
        ax.plot(self.ts, np.mean(self.runner.mon['I.E_current'][:, I_neuron_idx], axis=1), label='E')
        ax = axes[1, 1]
        ax.plot(self.ts, np.mean(self.runner.mon['I.I_current'][:, I_neuron_idx], axis=1), label='I')
        ax = axes[2, 1]
        ax.plot(self.ts, np.mean(self.runner.mon['I.input_current'][:, I_neuron_idx], axis=1), label='input')
        for ax in axes[:, 1]:
            ax.set_title('I')
            ax.legend()

        fig.suptitle('Current')
        fig.savefig(os.path.join(self.basedir, 'current', f'current_{E_neuron_idx}_{I_neuron_idx}.png'))
        plt.close(fig)


if __name__ == '__main__':
    bm.set_dt(0.1)
    process_num = 1
    grid_num = 100
    grid_distance = 6.1 * 10**(-3)  # mm

    # "small packet" case (overridden by the case below)
    E_conn_grid_num = 10
    I_conn_grid_num = 40
    wE = 20. * 0.2235 * 0.001 / np.sqrt(0.4) * 80
    wI = 20. * 0.0578 * 0.001 / np.sqrt(0.4) * 40

    # "connected sheet" case (the values actually used)
    E_conn_grid_num = 36
    I_conn_grid_num = 12
    wE = 20. * 0.2235 * 0.001 / np.sqrt(0.4) * 50
    wI = 20. * 0.0578 * 0.001 / np.sqrt(0.4) * 120

    sigmaE = 18.
    time_period = 1000.  # ms

    def E_inp_func(t):
        r = np.zeros((grid_num, grid_num))
        return r.flatten()

    I_inp_func = lambda x: 0.

    snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE,
                                E_inp_func, I_inp_func, time_period)
    snn_analyzer.visualize_spike()
    snn_analyzer.visualize_V_one_neuron('E', 2, 17)
    snn_analyzer.visualize_V_one_neuron('E', 2, 18)
    snn_analyzer.visualize_V_one_neuron('E', 2, 2)
    snn_analyzer.visualize_V_one_neuron('E', 18, 18)
    snn_analyzer.visualize_V_one_neuron('E', 3, 18)
    snn_analyzer.visualize_current()
    snn_analyzer.visualize_current(E_neuron_idx=0, I_neuron_idx=0)
    snn_analyzer.visualize_current(E_neuron_idx=grid_num//2 * grid_num + grid_num//2,
                                   I_neuron_idx=grid_num//2//2 * grid_num//2 + grid_num//2//2)
    snn_analyzer.visualize_V(start=500, stop=4000, step=25, frame_rate=15, delete_figs=True)
Traceback (most recent call last):
Traceback (most recent call last): File "C:\Program Files\Python311\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2150, in _lower_jaxpr_to_fun_cached func_op = ctx.cached_primitive_lowerings[key] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^ KeyError: (None, let _where = { lambda ; a:bool[10000] b:f32[] c:f32[10000]. let d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b e:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] d f:f32[10000] = select_n a c e in (f,) } in let _where1 = { lambda ; g:bool[2500] h:f32[] i:f32[2500]. let j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h k:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] j l:f32[2500] = select_n g i k in (l,) } in { lambda ; m:f32[40530000] n:i32[40530000] o:i32[10001] p:f32[10132500] q:i32[10132500] r:i32[10001] s:f32[1102500] t:i32[1102500] u:i32[2501] v:f32[282500] w:i32[282500] x:i32[2501] y:f32[10000] z:f32[2500] ba:f32[2500] bb:f32[2500] bc:bool[10000] bd:f32[2500] be:f32[10000] bf:bool[2500] bg:f32[10000] bh:f32[2500] bi:f32[10000] bj:f32[10000] bk:f32[10000] bl:bool[10000] bm:f32[1] bn:f32[2500] bo:f32[10000] bp:f32[10000] bq:f32[1] br:f32[2500] bs:f32[10000] bt:f32[10000] bu:f32[2500] bv:f32[2500] bw:bool[2500] bx:f32[10000] by:f32[10000] bz:f32[2500] ca:f32[2500] cb:i32[]. 
let cc:f32[] = convert_element_type[new_dtype=float32 weak_type=True] cb cd:f32[] = mul cc 0.1 ce:f32[] = add 0.0 cd cf:f32[10000] = braintaichi_custom_op_43[ float_as_event=True outs=(ShapedArray(float32[10000]),) shape=(10000, 10000) transpose=True ] m n o bc cg:f32[10000] = add bo cf ch:f32[10000] = add bp cf ci:f32[2500] = braintaichi_custom_op_43[ float_as_event=True outs=(ShapedArray(float32[2500]),) shape=(10000, 2500) transpose=True ] p q r bc cj:f32[2500] = add bh ci ck:f32[2500] = add br ci cl:f32[10000] = braintaichi_custom_op_43[ float_as_event=True outs=(ShapedArray(float32[10000]),) shape=(2500, 10000) transpose=True ] s t u bf cm:f32[10000] = add bs cl cn:f32[10000] = add bt cl co:f32[2500] = braintaichi_custom_op_43[ float_as_event=True outs=(ShapedArray(float32[2500]),) shape=(2500, 2500) transpose=True ] v w x bf cp:f32[2500] = add bu co cq:f32[2500] = add bv co cr:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 cs:f32[10000] = mul cr 0.019999999552965164 ct:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 cu:f32[2500] = mul ct 0.019999999552965164 cv:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 cw:f32[10000] = mul cv 0.019999999552965164 cx:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 cy:f32[2500] = mul cx 0.019999999552965164 cz:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 da:f32[10000] = device_put[ copy_semantics=[<CopySemantics.ALIAS: 1>] devices=[None] srcs=[None] ] y db:f32[10000] = mul cz da dc:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 dd:f32[2500] = mul dc 0.0 de:i32[10000] = convert_element_type[new_dtype=int32 weak_type=True] bl df:i32[10000] = mul de 10000 dg:f32[10000] = convert_element_type[new_dtype=float32 weak_type=False] df dh:f32[10000] = add bx dg di:f32[10000] = 
convert_element_type[new_dtype=float32 weak_type=False] df dj:f32[10000] = add by di dk:i32[2500] = convert_element_type[new_dtype=int32 weak_type=True] bw dl:i32[2500] = mul dk 2500 dm:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl dn:f32[2500] = add bz dm do:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl dp:f32[2500] = add ca do dq:f32[10000] = neg cg dr:f32[10000] = div dq 0.30000001192092896 ds:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 dt:f32[10000] = div ds 0.30000001192092896 du:f32[10000] = neg dt dv:f32[10000] = mul 0.10000000149011612 du dw:f32[10000] = abs dv dx:bool[10000] = le dw 9.999999747378752e-06 dy:f32[10000] = div dv 2.0 dz:f32[10000] = add 1.0 dy ea:f32[10000] = mul dv dv eb:f32[10000] = div ea 6.0 ec:f32[10000] = add dz eb ed:f32[10000] = exp dv ee:f32[10000] = sub ed 1.0 ef:f32[10000] = div ee dv eg:f32[10000] = select_n dx ef ec eh:f32[10000] = mul 0.10000000149011612 eg ei:f32[10000] = mul eh dr ej:f32[10000] = add cg ei ek:f32[10000] = neg ch el:f32[10000] = div ek 2.0 em:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 en:f32[10000] = div em 2.0 eo:f32[10000] = neg en ep:f32[10000] = mul 0.10000000149011612 eo eq:f32[10000] = abs ep er:bool[10000] = le eq 9.999999747378752e-06 es:f32[10000] = div ep 2.0 et:f32[10000] = add 1.0 es eu:f32[10000] = mul ep ep ev:f32[10000] = div eu 6.0 ew:f32[10000] = add et ev ex:f32[10000] = exp ep ey:f32[10000] = sub ex 1.0 ez:f32[10000] = div ey ep fa:f32[10000] = select_n er ez ew fb:f32[10000] = mul 0.10000000149011612 fa fc:f32[10000] = mul fb el fd:f32[10000] = add ch fc fe:f32[10000] = sub fd ej ff:f32[10000] = mul 0.5882353186607361 fe fg:f32[10000] = neg cm fh:f32[10000] = div fg 0.30000001192092896 fi:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 fj:f32[10000] = div fi 0.30000001192092896 fk:f32[10000] = neg fj 
fl:f32[10000] = mul 0.10000000149011612 fk fm:f32[10000] = abs fl fn:bool[10000] = le fm 9.999999747378752e-06 fo:f32[10000] = div fl 2.0 fp:f32[10000] = add 1.0 fo fq:f32[10000] = mul fl fl fr:f32[10000] = div fq 6.0 fs:f32[10000] = add fp fr ft:f32[10000] = exp fl fu:f32[10000] = sub ft 1.0 fv:f32[10000] = div fu fl fw:f32[10000] = select_n fn fv fs fx:f32[10000] = mul 0.10000000149011612 fw fy:f32[10000] = mul fx fh fz:f32[10000] = add cm fy ga:f32[10000] = neg cn gb:f32[10000] = div ga 3.0 gc:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 gd:f32[10000] = div gc 3.0 ge:f32[10000] = neg gd gf:f32[10000] = mul 0.10000000149011612 ge gg:f32[10000] = abs gf gh:bool[10000] = le gg 9.999999747378752e-06 gi:f32[10000] = div gf 2.0 gj:f32[10000] = add 1.0 gi gk:f32[10000] = mul gf gf gl:f32[10000] = div gk 6.0 gm:f32[10000] = add gj gl gn:f32[10000] = exp gf go:f32[10000] = sub gn 1.0 gp:f32[10000] = div go gf gq:f32[10000] = select_n gh gp gm gr:f32[10000] = mul 0.10000000149011612 gq gs:f32[10000] = mul gr gb gt:f32[10000] = add cn gs gu:f32[10000] = sub gt fz gv:f32[10000] = mul 0.37037035822868347 gu gw:f32[10000] = neg dh gx:f32[10000] = div gw 0.30000001192092896 gy:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 gz:f32[10000] = div gy 0.30000001192092896 ha:f32[10000] = neg gz hb:f32[10000] = mul 0.10000000149011612 ha hc:f32[10000] = abs hb hd:bool[10000] = le hc 9.999999747378752e-06 he:f32[10000] = div hb 2.0 hf:f32[10000] = add 1.0 he hg:f32[10000] = mul hb hb hh:f32[10000] = div hg 6.0 hi:f32[10000] = add hf hh hj:f32[10000] = exp hb hk:f32[10000] = sub hj 1.0 hl:f32[10000] = div hk hb hm:f32[10000] = select_n hd hl hi hn:f32[10000] = mul 0.10000000149011612 hm ho:f32[10000] = mul hn gx hp:f32[10000] = add dh ho hq:f32[10000] = neg dj hr:f32[10000] = div hq 2.0 hs:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 ht:f32[10000] = div hs 
2.0 hu:f32[10000] = neg ht hv:f32[10000] = mul 0.10000000149011612 hu hw:f32[10000] = abs hv hx:bool[10000] = le hw 9.999999747378752e-06 hy:f32[10000] = div hv 2.0 hz:f32[10000] = add 1.0 hy ia:f32[10000] = mul hv hv ib:f32[10000] = div ia 6.0 ic:f32[10000] = add hz ib id:f32[10000] = exp hv ie:f32[10000] = sub id 1.0 if:f32[10000] = div ie hv ig:f32[10000] = select_n hx if ic ih:f32[10000] = mul 0.10000000149011612 ig ii:f32[10000] = mul ih hr ij:f32[10000] = add dj ii ik:f32[10000] = sub ij hp il:f32[10000] = mul 0.5882353186607361 ik im:f32[10000] = sub 0.0 be in:f32[10000] = mul ff im io:f32[10000] = add in 0.0 ip:f32[10000] = sub -80.0 be iq:f32[10000] = mul gv ip ir:f32[10000] = add io iq is:f32[10000] = sub 0.0 be it:f32[10000] = mul cs is iu:f32[10000] = add ir it iv:f32[10000] = sub -80.0 be iw:f32[10000] = mul cw iv ix:f32[10000] = add iu iw iy:f32[10000] = add ix db iz:f32[10000] = sub 0.0 be ja:f32[10000] = mul il iz jb:f32[10000] = add iy ja jc:f32[10000] = sub be -60.625 jd:f32[10000] = div jc 6.5625 je:f32[10000] = exp jd jf:f32[10000] = mul 6.5625 je jg:f32[10000] = sub be -70.0 jh:f32[10000] = neg jg ji:f32[10000] = add jh jf jj:f32[10000] = mul 1.0 jb jk:f32[10000] = add ji jj jl:f32[10000] = div jk 20.0 jm:f32[10000] = broadcast_in_dim[ broadcast_dimensions=() shape=(10000,) sharding=None ] 1.0 jn:f32[10000] = div jm 20.0 jo:f32[10000] = mul 6.5625 jn jp:f32[10000] = mul jo je jq:f32[10000] = div jp 6.5625 jr:f32[10000] = neg jn js:f32[10000] = add_any jq jr jt:f32[10000] = mul 0.10000000149011612 js ju:f32[10000] = abs jt jv:bool[10000] = le ju 9.999999747378752e-06 jw:f32[10000] = div jt 2.0 jx:f32[10000] = add 1.0 jw jy:f32[10000] = mul jt jt jz:f32[10000] = div jy 6.0 ka:f32[10000] = add jx jz kb:f32[10000] = exp jt kc:f32[10000] = sub kb 1.0 kd:f32[10000] = div kc jt ke:f32[10000] = select_n jv kd ka kf:f32[10000] = mul 0.10000000149011612 ke kg:f32[10000] = mul kf jl kh:f32[10000] = add be kg ki:f32[10000] = add kh 0.0 kj:f32[] = 
convert_element_type[new_dtype=float32 weak_type=False] ce kk:f32[10000] = sub kj bg kl:bool[10000] = le kk 5.0 km:f32[10000] = pjit[ name=_where jaxpr={ lambda ; kn:bool[10000] ko:f32[10000] kp:f32[10000]. let kq:f32[10000] = select_n kn kp ko in (kq,) } ] kl be ki kr:bool[10000] = ge km -40.0 ks:f32[10000] = pjit[name=_where jaxpr=_where] kr -70.0 km kt:f32[10000] = pjit[name=_where jaxpr=_where] kr ce bg ku:f32[2500] = neg cj kv:f32[2500] = div ku 0.30000001192092896 kw:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 kx:f32[2500] = div kw 0.30000001192092896 ky:f32[2500] = neg kx kz:f32[2500] = mul 0.10000000149011612 ky la:f32[2500] = abs kz lb:bool[2500] = le la 9.999999747378752e-06 lc:f32[2500] = div kz 2.0 ld:f32[2500] = add 1.0 lc le:f32[2500] = mul kz kz lf:f32[2500] = div le 6.0 lg:f32[2500] = add ld lf lh:f32[2500] = exp kz li:f32[2500] = sub lh 1.0 lj:f32[2500] = div li kz lk:f32[2500] = select_n lb lj lg ll:f32[2500] = mul 0.10000000149011612 lk lm:f32[2500] = mul ll kv ln:f32[2500] = add cj lm lo:f32[2500] = neg ck lp:f32[2500] = div lo 2.0 lq:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 lr:f32[2500] = div lq 2.0 ls:f32[2500] = neg lr lt:f32[2500] = mul 0.10000000149011612 ls lu:f32[2500] = abs lt lv:bool[2500] = le lu 9.999999747378752e-06 lw:f32[2500] = div lt 2.0 lx:f32[2500] = add 1.0 lw ly:f32[2500] = mul lt lt lz:f32[2500] = div ly 6.0 ma:f32[2500] = add lx lz mb:f32[2500] = exp lt mc:f32[2500] = sub mb 1.0 md:f32[2500] = div mc lt me:f32[2500] = select_n lv md ma mf:f32[2500] = mul 0.10000000149011612 me mg:f32[2500] = mul mf lp mh:f32[2500] = add ck mg mi:f32[2500] = sub mh ln mj:f32[2500] = mul 0.5882353186607361 mi mk:f32[2500] = neg cp ml:f32[2500] = div mk 0.30000001192092896 mm:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 mn:f32[2500] = div mm 0.30000001192092896 mo:f32[2500] = neg mn mp:f32[2500] = mul 
0.10000000149011612 mo mq:f32[2500] = abs mp mr:bool[2500] = le mq 9.999999747378752e-06 ms:f32[2500] = div mp 2.0 mt:f32[2500] = add 1.0 ms mu:f32[2500] = mul mp mp mv:f32[2500] = div mu 6.0 mw:f32[2500] = add mt mv mx:f32[2500] = exp mp my:f32[2500] = sub mx 1.0 mz:f32[2500] = div my mp na:f32[2500] = select_n mr mz mw nb:f32[2500] = mul 0.10000000149011612 na nc:f32[2500] = mul nb ml nd:f32[2500] = add cp nc ne:f32[2500] = neg cq nf:f32[2500] = div ne 3.0 ng:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 nh:f32[2500] = div ng 3.0 ni:f32[2500] = neg nh nj:f32[2500] = mul 0.10000000149011612 ni nk:f32[2500] = abs nj nl:bool[2500] = le nk 9.999999747378752e-06 nm:f32[2500] = div nj 2.0 nn:f32[2500] = add 1.0 nm no:f32[2500] = mul nj nj np:f32[2500] = div no 6.0 nq:f32[2500] = add nn np nr:f32[2500] = exp nj ns:f32[2500] = sub nr 1.0 nt:f32[2500] = div ns nj nu:f32[2500] = select_n nl nt nq nv:f32[2500] = mul 0.10000000149011612 nu nw:f32[2500] = mul nv nf nx:f32[2500] = add cq nw ny:f32[2500] = sub nx nd nz:f32[2500] = mul 0.37037035822868347 ny oa:f32[2500] = neg dn ob:f32[2500] = div oa 0.30000001192092896 oc:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 od:f32[2500] = div oc 0.30000001192092896 oe:f32[2500] = neg od of:f32[2500] = mul 0.10000000149011612 oe og:f32[2500] = abs of oh:bool[2500] = le og 9.999999747378752e-06 oi:f32[2500] = div of 2.0 oj:f32[2500] = add 1.0 oi ok:f32[2500] = mul of of ol:f32[2500] = div ok 6.0 om:f32[2500] = add oj ol on:f32[2500] = exp of oo:f32[2500] = sub on 1.0 op:f32[2500] = div oo of oq:f32[2500] = select_n oh op om or:f32[2500] = mul 0.10000000149011612 oq os:f32[2500] = mul or ob ot:f32[2500] = add dn os ou:f32[2500] = neg dp ov:f32[2500] = div ou 2.0 ow:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 ox:f32[2500] = div ow 2.0 oy:f32[2500] = neg ox oz:f32[2500] = mul 0.10000000149011612 oy pa:f32[2500] = 
abs oz pb:bool[2500] = le pa 9.999999747378752e-06 pc:f32[2500] = div oz 2.0 pd:f32[2500] = add 1.0 pc pe:f32[2500] = mul oz oz pf:f32[2500] = div pe 6.0 pg:f32[2500] = add pd pf ph:f32[2500] = exp oz pi:f32[2500] = sub ph 1.0 pj:f32[2500] = div pi oz pk:f32[2500] = select_n pb pj pg pl:f32[2500] = mul 0.10000000149011612 pk pm:f32[2500] = mul pl ov pn:f32[2500] = add dp pm po:f32[2500] = sub pn ot pp:f32[2500] = mul 0.5882353186607361 po pq:f32[2500] = sub 0.0 bb pr:f32[2500] = mul mj pq ps:f32[2500] = add pr 0.0 pt:f32[2500] = sub -80.0 bb pu:f32[2500] = mul nz pt pv:f32[2500] = add ps pu pw:f32[2500] = sub 0.0 bb px:f32[2500] = mul cu pw py:f32[2500] = add pv px pz:f32[2500] = sub -80.0 bb qa:f32[2500] = mul cy pz qb:f32[2500] = add py qa qc:f32[2500] = add qb dd qd:f32[2500] = sub 0.0 bb qe:f32[2500] = mul pp qd qf:f32[2500] = add qc qe qg:f32[2500] = sub bb -60.625 qh:f32[2500] = div qg 6.5625 qi:f32[2500] = exp qh qj:f32[2500] = mul 6.5625 qi qk:f32[2500] = sub bb -70.0 ql:f32[2500] = neg qk qm:f32[2500] = add ql qj qn:f32[2500] = mul 1.0 qf qo:f32[2500] = add qm qn qp:f32[2500] = div qo 20.0 qq:f32[2500] = broadcast_in_dim[ broadcast_dimensions=() shape=(2500,) sharding=None ] 1.0 qr:f32[2500] = div qq 20.0 qs:f32[2500] = mul 6.5625 qr qt:f32[2500] = mul qs qi qu:f32[2500] = div qt 6.5625 qv:f32[2500] = neg qr qw:f32[2500] = add_any qu qv qx:f32[2500] = mul 0.10000000149011612 qw qy:f32[2500] = abs qx qz:bool[2500] = le qy 9.999999747378752e-06 ra:f32[2500] = div qx 2.0 rb:f32[2500] = add 1.0 ra rc:f32[2500] = mul qx qx rd:f32[2500] = div rc 6.0 re:f32[2500] = add rb rd rf:f32[2500] = exp qx rg:f32[2500] = sub rf 1.0 rh:f32[2500] = div rg qx ri:f32[2500] = select_n qz rh re rj:f32[2500] = mul 0.10000000149011612 ri rk:f32[2500] = mul rj qp rl:f32[2500] = add bb rk rm:f32[2500] = add rl 0.0 rn:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ce ro:f32[2500] = sub rn bd rp:bool[2500] = le ro 5.0 rq:f32[2500] = pjit[ name=_where jaxpr={ lambda ; 
rr:bool[2500] rs:f32[2500] rt:f32[2500]. let ru:f32[2500] = select_n rr rt rs in (ru,) } ] rp bb rm rv:bool[2500] = ge rq -40.0 rw:f32[2500] = pjit[name=_where jaxpr=_where1] rv -70.0 rq rx:f32[2500] = pjit[name=_where jaxpr=_where1] rv ce bd ry:f32[10000] = sub 0.0 ks rz:f32[10000] = mul ff ry sa:f32[10000] = add rz 0.0 sb:f32[10000] = sub 0.0 ks sc:f32[10000] = mul cs sb sd:f32[10000] = add sa sc se:f32[10000] = sub -80.0 ks sf:f32[10000] = mul gv se sg:f32[10000] = add sf 0.0 sh:f32[10000] = sub -80.0 ks si:f32[10000] = mul cw sh sj:f32[10000] = add sg si sk:f32[2500] = sub 0.0 rw sl:f32[2500] = mul mj sk sm:f32[2500] = add sl 0.0 sn:f32[2500] = sub 0.0 rw so:f32[2500] = mul cu sn sp:f32[2500] = add sm so sq:f32[2500] = sub -80.0 rw sr:f32[2500] = mul nz sq ss:f32[2500] = add sr 0.0 st:f32[2500] = sub -80.0 rw su:f32[2500] = mul cy st sv:f32[2500] = add ss su sw:f32[10000] = add 0.0 db sx:f32[10000] = sub 0.0 ks sy:f32[10000] = mul il sx sz:f32[10000] = add sw sy ta:f32[2500] = add 0.0 dd tb:f32[2500] = sub 0.0 rw tc:f32[2500] = mul pp tb td:f32[2500] = add ta tc te:f32[10000] = sub fd ej tf:f32[10000] = mul te 0.5882353186607361 tg:f32[2500] = sub mh ln th:f32[2500] = mul tg 0.5882353186607361 ti:f32[10000] = sub gt fz tj:f32[10000] = mul ti 0.37037035822868347 tk:f32[2500] = sub nx nd tl:f32[2500] = mul tk 0.37037035822868347 debug_callback[ callback=<function debug_callback.<locals>._flat_callback at 0x000001A9A3D22FC0> effect=Debug ] in (dd, cu, rw, kr, rx, ks, rv, kt, ln, cs, db, cw, bl, bm, cy, ej, fd, bq, mh, fz, gt, nd, nx, bw, hp, ij, ot, pn, tf, th, sd, sj, ks, sz, kr, sp, tj, tl, sv, rw, td, rv) }, ()) During handling of the above exception, another exception occurred: Traceback (most recent call last): File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module> snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period) File 
"C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__ self.runner.predict(duration=time_period) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict outputs, hists = self._predict(indices, *inputs, shared_args=shared_args) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict return bm.for_loop(self._step_func_predict, File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop dyn_vals, out_vals = transform(operands) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call return jax.lax.scan(f=fun2scan, File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 727, in fun2scan results = body_fun(*x, **unroll_kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 627, in _step_func_predict out = self.target(*x) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__ ret = self.update(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update return update_fun(*args, **kwargs) File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 433, in update self.E2E() File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__ ret = self.update(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update return update_fun(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 605, in update 
node.update(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update return update_fun(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dyn\projections\align_post.py", line 273, in update current = self.comm(x) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__ ret = self.update(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update return update_fun(*args, **kwargs) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dnn\linear.py", line 714, in update return bm.event.csrmv(self.weight, self.indices, self.indptr, x, File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\event\csr_matvec.py", line 68, in csrmv return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\main.py", line 134, in event_csrmv return event_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\csrmv.py", line 91, in event_csrmv_taichi return prim( File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_xla_custom_op.py", line 116, in __call__ return self.primitive.bind(*ins, outs=outs, **kwargs) jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi. The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception. 
-------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last): File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module> snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__ self.runner.predict(duration=time_period) File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict outputs, hists = self._predict(indices, *inputs, shared_args=shared_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict return bm.for_loop(self._step_func_predict, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop dyn_vals, out_vals = transform(operands) ^^^^^^^^^^^^^^^^^^^ File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call return jax.lax.scan(f=fun2scan, ^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_mlir_translation_rule.py", line 441, in _taichi_mlir_cpu_translation_rule raise RuntimeError( RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi. 
-------------------- For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Packages
Package Version ------------------------- ------------------ albumentations 1.3.1 altgraph 0.17.4 asttokens 2.4.1 brainpy 2.6.0.post20241205 brainstate 0.1.0.post20241210 braintaichi 0.0.3 brainunit 0.0.3.post20241211 certifi 2023.7.22 charset-normalizer 3.3.0 colorama 0.4.6 colorlog 6.8.2 comm 0.1.4 contourpy 1.1.1 cycler 0.12.1 debugpy 1.8.0 decorator 5.1.1 dictdiffer 0.9.0 dill 0.3.9 einops 0.7.0 entmax 1.1 et-xmlfile 1.1.0 executing 2.0.1 filelock 3.12.4 fonttools 4.43.1 fsspec 2023.9.2 huggingface-hub 0.17.3 idna 3.4 imageio 2.31.5 ipykernel 6.26.0 ipython 8.17.1 ipywidgets 8.1.5 jax 0.4.37 jaxlib 0.4.36 jedi 0.19.1 Jinja2 3.1.2 joblib 1.3.2 jupyter_client 8.5.0 jupyter_core 5.5.0 jupyterlab_widgets 3.0.13 kiwisolver 1.4.5 lazy_loader 0.3 llvmlite 0.43.0 lxml 4.9.3 markdown-it-py 3.0.0 MarkupSafe 2.1.3 matplotlib 3.8.0 matplotlib-inline 0.1.6 mdurl 0.1.2 ml_dtypes 0.5.0 mpmath 1.3.0 munch 4.0.0 nest-asyncio 1.5.8 networkx 3.1 numba 0.60.0 numpy 1.24.4 opencv-python-headless 4.10.0.84 openpyxl 3.1.2 opt_einsum 3.4.0 packaging 24.0 pandas 2.1.1 parse 1.20.2 parso 0.8.3 pefile 2023.2.7 Pillow 10.0.1 pip 24.3.1 pix2tex 0.1.2 platformdirs 3.11.0 prompt-toolkit 3.0.39 psutil 5.9.8 pure-eval 0.2.2 pycocotools 2.0.8 pycryptodome 3.20.0 Pygments 2.18.0 pyinstaller 6.6.0 pyinstaller-hooks-contrib 2024.6 Pymem 1.13.1 pynput 1.7.6 pyparsing 3.1.1 PyQt6 6.7.0 PyQt6-Qt6 6.7.0 PyQt6-sip 13.6.0 PyQt6-WebEngine 6.5.0 PyQt6-WebEngine-Qt6 6.5.3 pyreadline3 3.4.1 PySide6 6.5.3 PySide6-Addons 6.5.3 PySide6-Essentials 6.5.3 pystache 0.6.5 python-dateutil 2.8.2 pytz 2023.3.post1 pywin32 306 pywin32-ctypes 0.2.2 PyYAML 6.0.2 pyzmq 25.1.1 qudida 0.0.4 regex 2023.10.3 requests 2.31.0 resolvelib 1.0.1 rich 13.9.4 ruamel.yaml 0.18.6 ruamel.yaml.clib 0.2.8 safetensors 0.4.0 scikit-image 0.22.0 scikit-learn 1.3.1 scipy 1.14.1 screeninfo 0.8.1 setuptools 65.5.0 shiboken6 6.5.3 six 1.16.0 stack-data 0.6.3 sympy 1.12 taichi 1.7.2 threadpoolctl 3.2.0 tifffile 2023.9.26 timm 0.5.4 tokenizers 
0.14.1 torch 2.1.0 torchaudio 2.1.0 torchvision 0.16.0 tornado 6.3.3 tqdm 4.67.1 traitlets 5.13.0 transformers 4.34.0 typing_extensions 4.12.2 tzdata 2023.3 urllib3 2.0.6 watchdog 4.0.0 wcwidth 0.2.9 widgetsnbextension 4.0.13 x-transformers 0.15.0
Metadata
Assignees
Labels
bug: Something isn't working