improve multiprocessing

This commit is contained in:
2024-08-31 22:24:46 +02:00
parent 6ea763f914
commit f50a36c4d0

View File

@@ -1,3 +1,4 @@
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
@@ -86,41 +87,46 @@ def run_simulation(n_servers=2, lam=40):
}
with ProcessPoolExecutor(max_workers=4) as executor:
n_servers = [n + 1 for n in range(10)]
lam = [(n + 1) * 4 for n in range(10)]
parameter_values = [
{"n_servers": n, "lam": m} for n in n_servers for m in lam for _ in range(10)
]
futures = [
executor.submit(run_simulation, **parameters) for parameters in parameter_values
]
if __name__ == "__main__":
with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
n_servers = [n + 1 for n in range(10)]
lam = [(n + 1) * 4 for n in range(10)]
parameter_values = [
{"n_servers": n, "lam": m}
for n in n_servers
for m in lam
for _ in range(100)
]
futures = [
executor.submit(run_simulation, **parameters)
for parameters in parameter_values
]
results = []
for future in tqdm(as_completed(futures), total=len(parameter_values)):
result = future.result()
results.append(result)
results = []
for future in tqdm(as_completed(futures), total=len(parameter_values)):
result = future.result()
results.append(result)
df = (
pl.DataFrame(results)
.fill_null(0)
.group_by(["n_servers", "lam"])
.agg(pl.all().mean())
.sort(["n_servers", "lam"])
)
df_queue_time = df.pivot("lam", index="n_servers", values="queue_time").sort(
"n_servers"
)
def stats(column):
return (
df.pivot("lam", index="n_servers", values=column)
.sort("n_servers")
.with_columns(pl.all().round(2))
df = (
pl.DataFrame(results)
.fill_null(0)
.group_by(["n_servers", "lam"])
.agg(pl.all().mean())
.sort(["n_servers", "lam"])
)
pl.Config.set_tbl_cols(20)
print(df)
print(stats("queue_time"))
print(stats("utilization_server_1"))
df_queue_time = df.pivot("lam", index="n_servers", values="queue_time").sort(
"n_servers"
)
def stats(column):
return (
df.pivot("lam", index="n_servers", values=column)
.sort("n_servers")
.with_columns(pl.all().round(2))
)
pl.Config.set_tbl_cols(20)
print(df)
print(stats("queue_time"))
print(stats("utilization_server_1"))