DOMANDA Valutazione grafica dei modelli di regressione

Pubblicità

Giulio95

Nuovo Utente
Messaggi
37
Reazioni
4
Punteggio
26
Salve,

vorrei chiedere se effettuando una regressione ci sono alcune "tecniche" per capire se il modello risultante e' sufficientemente esplicativo.

Chiaramente ci sono le varie metriche da interpretare.

Tuttavia, finche' si fa una regressione con vettori X di lunghezza massima 2, si puo' ottimizzare la valutazione ricorrendo ad un grafico per verificare visivamente la qualita' del modello.

Se aggiungessimo ulteriori variabili al vettore X le dimensioni sarebbero troppe per un grafico a dispersione.
In questo caso esistono pratiche particolari che si possono seguire ?

Riporto il mio ultimo esercizio di regressione su TensorFlow.js.
2 variabili X e un etichetta.
Che si visualizzano in un grafico 3D alla fine.

Grazie mille e saluti.

HTML:
<html>
  <head>
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/axios@latest/dist/axios.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@latest/dist/tfjs-vis.umd.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
    <link rel="stylesheet" href="https://cdn.datatables.net/2.0.5/css/dataTables.dataTables.css" />
    <script src="https://cdn.datatables.net/2.0.5/js/dataTables.js"></script>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
  </head>
  <body>
    <h4>Regression Model<hr/></h4>
    <div id="micro-out-div" style="width: 60%;"></div>
    <div id="plot" style="width: 60%; height: 60%;"></div>
    <script>
                    
        axios.get('https://storage.googleapis.com/tfjs-tutorials/carsData.json').then(response => {
            const filteredData = response.data.filter(car => {
                if (car.Miles_per_Gallon != null && car.Horsepower != null) return true;
            });
            const tensors = mapData(filteredData);
            runModel(tensors);
        });

        // Create a simple model.
        const model = tf.sequential();
      
        model.add(tf.layers.dense({
            units: 120,
            activation: 'relu',
            inputShape: [2]
        }));
      
        model.add(tf.layers.dense({
            units: 64,
            activation: 'relu'
        }));   
      
        model.add(tf.layers.dense({
            units: 32,
            activation: 'relu'
        }));       
      
        model.add(tf.layers.dense({
            units: 1,
            activation: 'linear'
        }));
        
        model.compile({
            optimizer: tf.train.adam(),
            loss: tf.losses.huberLoss,
            metrics: ['mape'],
        });
        
        tfvis.show.modelSummary(
            {name: 'Model info'},
            model
        );
                        
        const mapData = (data) => {
            
            return tf.tidy(() => {

                const X = data.map(car => {
                    return [
                        car.Miles_per_Gallon,
                        car.Horsepower
                    ];
                });
              
                const Y = data.map(car => {
                    return car.Weight_in_lbs
                });

                const xs = tf.tensor2d(X, [X.length, 2]);
                const ys = tf.tensor1d(Y);
                              
                const xsmax = xs.max();
                const xsmin = xs.min();
                const ysmax = ys.max();
                const ysmin = ys.min();

                const nxs = xs.sub(xsmin).div(xsmax.sub(xsmin));
                const nys = ys.sub(ysmin).div(ysmax.sub(ysmin));
                                
                return {
                    xs,
                    ys,
                    nxs,
                    nys,
                    xsmax,
                    xsmin,
                    ysmax,
                    ysmin,
                    X,
                    Y
                };
                
            });
            
        }
      
        const runModel = ({
            xs,
            ys,
            nxs,
            nys,
            xsmax,
            xsmin,
            ysmax,
            ysmin,
            X,
            Y
        }) => {
            
            model.fit(nxs, nys, {
                batchSize: 164,
                epochs: 100,
                callbacks: tfvis.show.fitCallbacks(
                    {name: 'Train Performance'},
                    ['loss','mape'],
                    {height: 200, callbacks: ['onEpochEnd']}
                )
            }).then(() => {
                
                //Effettuo test delle predizioni sottomenttendo al modello gli stessi vettori X di addestramento
                const nprediction = model.predict(nxs);
                const prediction = nprediction.mul(ysmax.sub(ysmin)).add(ysmin);
                const predictionCollection = Array.from(prediction.dataSync());
                
                //associo in un unico oggetto le coppie di addestramento [X1, X2], le etichette Y, le predizioni P, la differenza tra predizioni ed etichette P - Y
                const mappedPrediction = X.map((couple, index) => {
                    return {
                        X: couple,
                        Y: Y[index],
                        P: predictionCollection[index],
                        D: predictionCollection[index] - Y[index]
                    };
                });
                
                //riporto in tabella mappedPrediction
                printTableWithPredictions(mappedPrediction);
                
                //Per il plot 3d separo le grandezze del vettore 2d X in due vettori 1d e lancio un grafico 3d per confrontare etichette e predizioni
                const XSeries = X.map(couple => { return couple[0]; });
                const YSeries = X.map(couple => { return couple[1]; });
                plot3d(XSeries, YSeries, Y, predictionCollection);
                
                //model.save('downloads://regression-model-1');
                
            });
        }
        
        const printTableWithPredictions = (mappedPrediction) => {
            const htSpace = document.querySelector('#micro-out-div');
            
            const table = document.createElement('table');
            table.id = 'dTable';
            table.classList.add('hover');
            table.classList.add('stripe');
            
            const thead = document.createElement('thead');
            
            const headerRow = document.createElement('tr');
            
            const headerX = document.createElement('th');
            headerX.innerText = 'X';
            
            const headerY = document.createElement('th');
            headerY.innerText = 'Y';

            const headerP = document.createElement('th');
            headerP.innerText = 'P';
            
            const headerD = document.createElement('th');
            headerD.innerText = 'P - Y';
            
            headerRow.appendChild(headerX);
            headerRow.appendChild(headerY);
            headerRow.appendChild(headerP);
            headerRow.appendChild(headerD);
            
            thead.appendChild(headerRow);
            table.appendChild(thead);
            
            const tbody = document.createElement('tbody');
            
            mappedPrediction.forEach(row => {
                const htRow = document.createElement('tr');
                
                const XData = document.createElement('td');
                XData.innerText = `${row.X[0]}; ${row.X[1]}`;
                
                const YData = document.createElement('td');
                YData.innerText = row.Y;
    
                const PData = document.createElement('td');
                PData.innerText = row.P;
                
                const DData = document.createElement('td');
                DData.innerText = row.D;
                
                htRow.appendChild(XData);
                htRow.appendChild(YData);
                htRow.appendChild(PData);
                htRow.appendChild(DData);
                                    
                tbody.appendChild(htRow);
            });
            
            table.appendChild(tbody);
            htSpace.appendChild(table);
            
            const dTable = new DataTable('#dTable', {});
        };
        
        const plot3d = (XSeries, YSeries, ZSeries, PSeries) => {
            const data1 = {
                x: XSeries,
                y: YSeries,
                z: ZSeries,
                type: 'scatter3d',
                mode: 'markers',
                marker: {
                    size: 2,
                    color: 'rgb(100,150,200)'
                }
            };
            
            const data2 = {
                x: XSeries,
                y: YSeries,
                z: PSeries,
                type: 'scatter3d',
                mode: 'markers',
                marker: {
                    size: 2,
                    color: 'rgb(255,0,0)'
                }
            };
            
            const data = [data1, data2];

            const layout = {
                scene: {
                    xaxis: { title: 'MPG' },
                    yaxis: { title: 'H' },
                    zaxis: { title: 'WIL' }
                }
            };

            Plotly.newPlot('plot', data, layout);
        };
                
    </script>
  </body>
</html>
 
L' interpretabilità di un modello è intrinsecamente legata al modello stesso. Ci sono ovviamente delle tecniche che potrebbero a seconda del caso e dei dati permettere di comprendere sommariamente il comportamento del modello.
Ad esempio potresti valutare di utilizzare la pca/t-sne per evidenziare quali sono le feature maggiormente discriminative.
Un approccio molto semplice da realizzare, ma altrettanto efficace è l'analisi della sensitività al variare dell'input.
Esiste anche la shap decisamente più completa ed elaborata.

Vien da se che utilizzare un multilayer perceptron non è il massimo dell'interpretabilità.
 
Ultima modifica:
Ciao, grazie mille per la risposta, approfondiro' le tecniche che mi hai citato.

Per pytorch ti capisco perfettamente XD.

Dovevo decidere con che libreria iniziare, e sono stato attratto dal fatto che potevo lanciare tutto nel browser.
Mi sembrava piu' adatto al neofita. Inoltre all'occorrenza genero facilmente componenti grafici come nell'esempio da me caricato.

Pero' sto parlando senza conoscere le potenzialita' degli strumenti Python.

Di nuovo grazie e saluti.
 
Pubblicità
Pubblicità
Indietro
Top