5 changes: 2 additions & 3 deletions .gitignore
@@ -1,9 +1,8 @@
 # dependencies
 /node_modules/
 
-# model.json files
-server/models
-cli/models
+# stored trained models
+models/
 
 # tsc built
 dist/
2 changes: 0 additions & 2 deletions cli/.gitignore

This file was deleted.

32 changes: 30 additions & 2 deletions cli/src/args.ts
@@ -10,6 +10,10 @@ interface BenchmarkArguments {
   epochs: number
   roundDuration: number
   batchSize: number
+  validationSplit: number
+  epsilon?: number
+  delta?: number
+  dpDefaultClippingRadius?: number
   save: boolean
   host: URL
 }
@@ -28,6 +32,10 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
     epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
     roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
     batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
+    validationSplit: { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
+    epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined },
+    delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined },
+    dpDefaultClippingRadius: { type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined },
     save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
     host: {
       type: (raw: string) => new URL(raw),
@@ -52,6 +60,7 @@ const supportedTasks = Map(
   defaultTasks.simpleFace,
   defaultTasks.titanic,
   defaultTasks.tinderDog,
+  defaultTasks.mnist,
 ).map(
   async (t) =>
     [(await t.getTask()).id, t] as [
@@ -77,10 +86,29 @@ export const args: BenchmarkArguments = {
     task.trainingInformation.batchSize = unsafeArgs.batchSize;
     task.trainingInformation.roundDuration = unsafeArgs.roundDuration;
     task.trainingInformation.epochs = unsafeArgs.epochs;
+    task.trainingInformation.validationSplit = unsafeArgs.validationSplit;
 
     // For DP
-    // TASK.trainingInformation.clippingRadius = 10000000
-    // TASK.trainingInformation.noiseScale = 0
+    const { dpDefaultClippingRadius, epsilon, delta } = unsafeArgs;
+
+    if (
+      // dpDefaultClippingRadius !== undefined &&
+      epsilon !== undefined &&
+      delta !== undefined
+    ) {
+      if (task.trainingInformation.scheme === "local")
+        throw new Error("Can't have differential privacy for local training");
+
+      const defaultRadius = dpDefaultClippingRadius ?? 1;
+
+      // for the case where privacy parameters are not defined in the default tasks
+      task.trainingInformation.privacy ??= {};
+      task.trainingInformation.privacy.differentialPrivacy = {
+        clippingRadius: defaultRadius,
+        epsilon: epsilon,
+        delta: delta,
+      };
+    }
 
     return task;
   },
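Note: the new `epsilon` (`-n`), `delta` (`-d`), and `dpDefaultClippingRadius` (`-f`) flags parameterize (ε, δ)-differential privacy. For intuition, here is a minimal sketch of how such a budget classically maps to Gaussian noise, using the textbook Gaussian-mechanism bound rather than whatever discojs computes internally:

```ts
// Sketch only: the Gaussian mechanism achieves (epsilon, delta)-DP on a
// query of L2-sensitivity C (here, the clipping radius) when the noise
// standard deviation satisfies sigma >= C * sqrt(2 * ln(1.25 / delta)) / epsilon.
// The bound assumes epsilon < 1, so treat large budgets (e.g. epsilon = 50
// in the cifar10 task below) as needing a tighter privacy accountant.
function gaussianNoiseScale(
  clippingRadius: number,
  epsilon: number,
  delta: number,
): number {
  return (clippingRadius * Math.sqrt(2 * Math.log(1.25 / delta))) / epsilon;
}
```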
2 changes: 1 addition & 1 deletion cli/src/cli.ts
@@ -51,7 +51,7 @@ async function main<D extends DataType, N extends Network>(
   console.log({ args })
 
   const dataSplits = await Promise.all(
-    Range(0, numberOfUsers).map(async i => getTaskData(task.id, i))
+    Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
   )
   const logs = await Promise.all(
     dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
89 changes: 80 additions & 9 deletions cli/src/data.ts
@@ -1,6 +1,7 @@
 import path from "node:path";
-import { Dataset, processing } from "@epfml/discojs";
-import type {
+import { promises as fs } from "fs";
+import { Dataset, processing, defaultTasks } from "@epfml/discojs";
+import {
   DataFormat,
   DataType,
   Image,
@@ -9,18 +10,20 @@ import type {
 import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
 import { Repeat } from "immutable";
 
-async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
+async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
   const folder = path.join("..", "datasets", "simple_face");
 
   const [adults, childs]: Dataset<[Image, string]>[] = [
     (await loadImagesInDir(path.join(folder, "adult"))).zip(Repeat("adult")),
     (await loadImagesInDir(path.join(folder, "child"))).zip(Repeat("child")),
   ];
 
-  return adults.chain(childs);
+  const combined = adults.chain(childs);
+
+  return combined.filter((_, i) => i % totalClient === userIdx);
 }
 
-async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
+async function loadLusCovidData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
   const folder = path.join("..", "datasets", "lus_covid");
 
   const [positive, negative]: Dataset<[Image, string]>[] = [
@@ -32,7 +35,11 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
     ),
   ];
 
-  return positive.chain(negative);
+  const combined: Dataset<[Image, string]> = positive.chain(negative);
+
+  const sharded = combined.filter((_, i) => i % totalClient === userIdx);
+
+  return sharded;
 }
 
 function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
@@ -59,25 +66,89 @@ function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
   });
 }
 
+async function loadExtCifar10(userIdx: number): Promise<Dataset<[Image, string]>> {
+  const CIFAR10_LABELS = Array.from(await defaultTasks.cifar10.getTask().then(t => t.trainingInformation.LABEL_LIST));
+  const folder = path.join("..", "datasets", "extended_cifar10");
+  const clientFolder = path.join(folder, `client_${userIdx}`);
+
+  return new Dataset(async function* () {
+    const entries = await fs.readdir(clientFolder, { withFileTypes: true });
+
+    const items = entries
+      .flatMap((e) => {
+        const m = e.name.match(
+          /^image_(\d+)_label_(\d+)\.png$/i,
+        );
+        if (m === null) return [];
+        const labelIdx = Number.parseInt(m[2], 10);
+
+        if (labelIdx >= CIFAR10_LABELS.length)
+          throw new Error(`${e.name}: too big label index`);
+
+        return {
+          name: e.name,
+          label: CIFAR10_LABELS[labelIdx],
+        };
+      })
+      .filter((x) => x !== null);
+
+    for (const { name, label } of items) {
+      const filePath = path.join(clientFolder, name);
+      const image = await loadImage(filePath);
+      yield [image, label] as const;
+    }
+  });
+}
+
+function loadMnistData(split: number): Dataset<DataFormat.Raw["image"]> {
+  const folder = path.join("..", "datasets", "mnist", `${split + 1}`);
+  return loadCSV(path.join(folder, "labels.csv"))
+    .map(
+      (row) =>
+        [
+          processing.extractColumn(row, "filename"),
+          processing.extractColumn(row, "label"),
+        ] as const,
+    )
+    .map(async ([filename, label]) => {
+      try {
+        const image = await Promise.any(
+          ["png", "jpg", "jpeg"].map((ext) =>
+            loadImage(path.join(folder, `${filename}.${ext}`)),
+          ),
+        );
+        return [image, label];
+      } catch {
+        throw new Error(`${filename} not found in ${folder}`);
+      }
+    });
+}
+
 export async function getTaskData<D extends DataType>(
   taskID: Task.ID,
   userIdx: number,
+  totalClient: number,
 ): Promise<Dataset<DataFormat.Raw[D]>> {
   switch (taskID) {
     case "simple_face":
-      return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
+      return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
     case "titanic":
-      return loadCSV(
+      const titanicData = loadCSV(
         path.join("..", "datasets", "titanic_train.csv"),
       ) as Dataset<DataFormat.Raw[D]>;
+      return titanicData.filter((_, i) => i % totalClient === userIdx);
     case "cifar10":
       return (
         await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
       ).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
     case "lus_covid":
-      return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
+      return (await loadLusCovidData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
     case "tinder_dog":
       return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
+    case "extended_cifar10":
+      return (await loadExtCifar10(userIdx)) as Dataset<DataFormat.Raw[D]>;
+    case "mnist":
+      return loadMnistData(userIdx) as Dataset<DataFormat.Raw[D]>;
     default:
       throw new Error(`Data loader for ${taskID} not implemented.`);
   }
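All the reworked loaders shard with the same modulo filter, so the simulated clients end up with disjoint, round-robin partitions of roughly equal size. A self-contained illustration with hypothetical values:

```ts
// Client k of n keeps exactly the items at indices i with i % n === k.
const items = ["a", "b", "c", "d", "e", "f", "g"];
const totalClient = 3;
for (let userIdx = 0; userIdx < totalClient; userIdx++) {
  console.log(userIdx, items.filter((_, i) => i % totalClient === userIdx));
}
// 0 [ 'a', 'd', 'g' ]
// 1 [ 'b', 'e' ]
// 2 [ 'c', 'f' ]
```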
27 changes: 15 additions & 12 deletions discojs/src/aggregator/get.ts
@@ -48,19 +48,22 @@ export function getAggregator(
   };
 
   switch (task.trainingInformation.aggregationStrategy) {
-    case 'byzantine': {
-      const {byzantineClippingRadius = 1.0, maxIterations = 1, beta = 0.9,
-      } = task.trainingInformation;
+    case "byzantine": {
+      const {
+        clippingRadius = 1.0,
+        maxIterations = 1,
+        beta = 0.9,
+      } = task.trainingInformation.privacy.byzantineFaultTolerance;
 
-      return new ByzantineRobustAggregator(
-        networkOptions.roundCutOff,
-        networkOptions.threshold,
-        networkOptions.thresholdType,
-        byzantineClippingRadius,
-        maxIterations,
-        beta
-      );
-    }
+      return new ByzantineRobustAggregator(
+        networkOptions.roundCutOff,
+        networkOptions.threshold,
+        networkOptions.thresholdType,
+        clippingRadius,
+        maxIterations,
+        beta,
+      );
+    }
     case 'mean':
       return new aggregator.MeanAggregator(
         networkOptions.roundCutOff,
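The `clippingRadius` / `maxIterations` / `beta` triple appears to match iterated centered clipping (Karimireddy et al.), where each iteration moves the aggregate toward the mean of the client updates after clipping each update's deviation from the current center, with `beta` as the associated momentum parameter. A minimal sketch under that assumption, over plain number arrays rather than the library's actual weight containers:

```ts
// Illustration only: not the ByzantineRobustAggregator implementation.
// One round of iterated centered clipping; the momentum term driven by
// `beta` (applied to client updates) is omitted for brevity.
function centeredClip(
  updates: number[][],
  previous: number[],
  clippingRadius: number,
  maxIterations: number,
): number[] {
  let center = [...previous];
  for (let it = 0; it < maxIterations; it++) {
    const sum = new Array<number>(center.length).fill(0);
    for (const u of updates) {
      const diff = u.map((x, j) => x - center[j]);
      const norm = Math.sqrt(diff.reduce((s, d) => s + d * d, 0));
      const scale = Math.min(1, clippingRadius / (norm || 1)); // clip to the radius
      diff.forEach((d, j) => { sum[j] += d * scale; });
    }
    // move the center toward the clipped average of the updates
    center = center.map((x, j) => x + sum[j] / updates.length);
  }
  return center;
}
```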
54 changes: 54 additions & 0 deletions discojs/src/dataset/dataset.ts
@@ -237,6 +237,60 @@
   cached(): Dataset<T> {
     return new CachingDataset(this.#content);
   }
+
+  /** Shuffles the Dataset within a sliding window of the given size. */
+  shuffle(windowSize: number): Dataset<T> {
+    if (!Number.isInteger(windowSize) || windowSize < 1) {
+      throw new Error("Shuffle window size should be a positive integer");
+    }
+
+    return new Dataset<T>(
+      async function* (this: Dataset<T>) {
+        const iter = this[Symbol.asyncIterator]();
+        const buffer: T[] = [];
+
+        // 1. Fill the initial window
+        while (buffer.length < windowSize) {
+          const n = await iter.next();
+          if (n.done) break;
+          buffer.push(n.value);
+        }
+
+        // 2. Emit a random buffered element, refilling from the source
+        while (buffer.length > 0) {
+          const pick = Math.floor(Math.random() * buffer.length);
+          const chosen = buffer[pick];
+
+          const n = await iter.next();
+
+          if (n.done) {
+            // source exhausted: shrink the buffer, backfilling the hole
+            // with the last element (unless the hole is the last slot)
+            const last = buffer.pop() as T;
+            if (pick < buffer.length) buffer[pick] = last;
+          } else {
+            buffer[pick] = n.value;
+          }
+
+          yield chosen;
+        }
+      }.bind(this),
+    );
+  }
+
+  /** Keeps only the elements for which the given predicate holds. */
+  filter(
+    condition: (value: T, index: number) => boolean | Promise<boolean>,
+  ): Dataset<T> {
+    return new Dataset<T>(
+      async function* (this: Dataset<T>): AsyncGenerator<T, void, unknown> {
+        let i = 0;
+        for await (const v of this) {
+          if (await condition(v, i)) {
+            yield v;
+          }
+          i += 1;
+        }
+      }.bind(this),
+    );
+  }
 }
 
 /**
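A short usage sketch of the two new `Dataset` methods (names and values are illustrative): `filter` streams lazily with an index-aware predicate, while `shuffle` is a windowed, reservoir-style shuffle, so `windowSize` bounds memory and a window at least as large as the dataset approaches a full uniform shuffle.

```ts
// Hypothetical usage: a ten-element source, even indices only,
// shuffled within a window of four buffered elements.
const source = new Dataset(async function* () {
  for (let i = 0; i < 10; i++) yield i;
});

const evens = source.filter((_, i) => i % 2 === 0); // 0, 2, 4, 6, 8
const shuffled = evens.shuffle(4); // random order, at most 4 items buffered

for await (const v of shuffled) console.log(v);
```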
15 changes: 10 additions & 5 deletions discojs/src/default_tasks/cifar10.ts
@@ -27,19 +27,24 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
     },
   },
   trainingInformation: {
-    epochs: 10,
+    epochs: 20,
     roundDuration: 10,
+    validationSplit: 0.2,
     batchSize: 10,
     IMAGE_H: 224,
     IMAGE_W: 224,
     LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
     scheme: 'decentralized',
-    aggregationStrategy: 'byzantine',
-    byzantineClippingRadius: 10.0,
-    maxIterations: 1,
-    beta: 0.9,
-    privacy: { clippingRadius: 20, noiseScale: 1 },
+    aggregationStrategy: 'mean',
+    privacy: {
+      differentialPrivacy: {
+        clippingRadius: 1,
+        epsilon: 50,
+        delta: 1e-5,
+      },
+    },
     minNbOfParticipants: 3,
     maxShareValue: 100,
     tensorBackend: 'tfjs'
@@ -66,7 +71,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
     model.compile({
       optimizer: 'sgd',
       loss: 'categoricalCrossentropy',
-      metrics: ['accuracy']
+      metrics: ['accuracy'],
     })
 
     return new models.TFJS('image', model)
9 changes: 8 additions & 1 deletion discojs/src/default_tasks/mnist.ts
@@ -33,7 +33,14 @@ export const mnist: TaskProvider<"image", "decentralized"> = {
     IMAGE_W: 28,
     LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
     scheme: 'decentralized',
-    aggregationStrategy: 'mean',
+    aggregationStrategy: "byzantine",
+    privacy: {
+      byzantineFaultTolerance: {
+        clippingRadius: 10,
+        maxIterations: 1,
+        beta: 0.9,
+      },
+    },
     minNbOfParticipants: 3,
     maxShareValue: 100,
     tensorBackend: 'tfjs'
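Taken together, the two task files and the `getAggregator` change imply that `trainingInformation.privacy` now acts as a container for both mechanisms. Its shape, reconstructed from usage in this PR rather than from the library's type definitions (the optionality of each field is an assumption):

```ts
// Inferred shape of the new privacy options (illustrative only).
interface PrivacyOptions {
  differentialPrivacy?: {
    clippingRadius: number; // L2 clipping bound, i.e. the query sensitivity
    epsilon: number;        // privacy budget
    delta: number;          // failure probability (slack)
  };
  byzantineFaultTolerance?: {
    clippingRadius: number; // centered-clipping radius
    maxIterations: number;  // clipping iterations per round
    beta: number;           // momentum parameter
  };
}
```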